diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000000000000000000000000000000000000..a6344aac8c09253b3b630fb776ae94478aa0275b --- /dev/null +++ b/.gitattributes @@ -0,0 +1,35 @@ +*.7z filter=lfs diff=lfs merge=lfs -text +*.arrow filter=lfs diff=lfs merge=lfs -text +*.bin filter=lfs diff=lfs merge=lfs -text +*.bz2 filter=lfs diff=lfs merge=lfs -text +*.ckpt filter=lfs diff=lfs merge=lfs -text +*.ftz filter=lfs diff=lfs merge=lfs -text +*.gz filter=lfs diff=lfs merge=lfs -text +*.h5 filter=lfs diff=lfs merge=lfs -text +*.joblib filter=lfs diff=lfs merge=lfs -text +*.lfs.* filter=lfs diff=lfs merge=lfs -text +*.mlmodel filter=lfs diff=lfs merge=lfs -text +*.model filter=lfs diff=lfs merge=lfs -text +*.msgpack filter=lfs diff=lfs merge=lfs -text +*.npy filter=lfs diff=lfs merge=lfs -text +*.npz filter=lfs diff=lfs merge=lfs -text +*.onnx filter=lfs diff=lfs merge=lfs -text +*.ot filter=lfs diff=lfs merge=lfs -text +*.parquet filter=lfs diff=lfs merge=lfs -text +*.pb filter=lfs diff=lfs merge=lfs -text +*.pickle filter=lfs diff=lfs merge=lfs -text +*.pkl filter=lfs diff=lfs merge=lfs -text +*.pt filter=lfs diff=lfs merge=lfs -text +*.pth filter=lfs diff=lfs merge=lfs -text +*.rar filter=lfs diff=lfs merge=lfs -text +*.safetensors filter=lfs diff=lfs merge=lfs -text +saved_model/**/* filter=lfs diff=lfs merge=lfs -text +*.tar.* filter=lfs diff=lfs merge=lfs -text +*.tar filter=lfs diff=lfs merge=lfs -text +*.tflite filter=lfs diff=lfs merge=lfs -text +*.tgz filter=lfs diff=lfs merge=lfs -text +*.wasm filter=lfs diff=lfs merge=lfs -text +*.xz filter=lfs diff=lfs merge=lfs -text +*.zip filter=lfs diff=lfs merge=lfs -text +*.zst filter=lfs diff=lfs merge=lfs -text +*tfevents* filter=lfs diff=lfs merge=lfs -text diff --git a/MMaDA/.cursor/rules/python-env.mdc b/MMaDA/.cursor/rules/python-env.mdc new file mode 100644 index 0000000000000000000000000000000000000000..c1ff6bbc9d8b0c60d891753ae6b69b02dffa02de --- /dev/null +++ b/MMaDA/.cursor/rules/python-env.mdc @@ -0,0 +1,4 @@ +--- +alwaysApply: true +--- +When running python script, use conda env `mmada`. \ No newline at end of file diff --git a/MMaDA/.gitignore b/MMaDA/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..4bf4b61a759ad4f6794799e6596ce371df3c8858 --- /dev/null +++ b/MMaDA/.gitignore @@ -0,0 +1,2 @@ +exp +wandb \ No newline at end of file diff --git a/MMaDA/AIDAS-Omni-Modal-Diffusion/app.py b/MMaDA/AIDAS-Omni-Modal-Diffusion/app.py new file mode 100644 index 0000000000000000000000000000000000000000..d27947cc49b2a23d6ea5de59999022f729c6c9d8 --- /dev/null +++ b/MMaDA/AIDAS-Omni-Modal-Diffusion/app.py @@ -0,0 +1,16 @@ +import gradio as gr +import spaces +import torch + +zero = torch.Tensor([0]).cuda() +print(zero.device) # should print 'cpu' until GPU context is enabled + + +@spaces.GPU +def greet(n): + print(zero.device) # now this should print 'cuda:0' + return f"Hello {zero + n} Tensor" + + +demo = gr.Interface(fn=greet, inputs=gr.Number(), outputs=gr.Text()) +demo.launch() diff --git a/MMaDA/LICENSE b/MMaDA/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..9ccb18421170127e4d4eab6963699a798d3b4ad3 --- /dev/null +++ b/MMaDA/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2025 Ling Yang + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/MMaDA/README.md b/MMaDA/README.md new file mode 100644 index 0000000000000000000000000000000000000000..d837c46be3c49c5e478d8b8c6e9fd37322c066d3 --- /dev/null +++ b/MMaDA/README.md @@ -0,0 +1,209 @@ +
+
+ +

Multimodal Large Diffusion Language Models

+ +

+ + MMaDA Paper on arXiv + + + MMaDA on Hugging Face + + + MMaDA on Hugging Face + + + MMaDA on Hugging Face + + + Wechat Group Link + + +

+ + +## 🌌 Introduction +MMaDA is a new family of **multimodal diffusion foundation models** designed to achieve superior performance across diverse domains such as textual reasoning, multimodal understanding, and text-to-image generation. MMaDA is distinguished by three key innovations: +1. MMaDA adopts a **unified diffusion architecture** with a shared probabilistic formulation and a modality-agnostic design, eliminating the need for modality-specific components. +2. MMaDA introduces a **mixed long chain-of-thought (CoT) fine-tuning** strategy that curates a unified CoT format across modalities. +3. MMaDA adopts a unified policy-gradient-based RL algorithm, which we call **UniGRPO**, tailored for diffusion foundation models. Utilizing diversified reward modeling, **UniGRPO** unifies post-training across both reasoning and generation tasks, ensuring consistent performance improvements. + +
+ MMaDA decoding demo +

+ MMaDA's decoding demo. This video showcases how a diffusion foundation model generates text and image.
+ The "Text Generation" part uses a semi-autoregressive sampling method, while the "Multimodal Generation" part adopts non-autoregressive diffusion denoising. +

+
+ + + + + + + + + +## šŸ“° Latest Updates +* **[2025-06-02]** We open source our **MMaDA-8B-MixCoT** at [Huggingface](https://huggingface.co/Gen-Verse/MMaDA-8B-MixCoT). +* **[2025-05-24]** We add support for MPS inference, tested on M4. +* **[2025-05-22]** We release the inference and training code of MMaDA for text generation, multimodal generation and image generation. +* **[2025-05-22]** We open source our **MMaDA-8B-Base** at [Huggingface](https://huggingface.co/Gen-Verse/MMaDA-8B-Base). **MMaDA-8B-MixCoT** and **MMaDA-8B-Max** will be released in the near future. +* **[2025-05-22]** We release our [research paper](https://arxiv.org/abs/2505.15809) and [demo](https://huggingface.co/spaces/Gen-Verse/MMaDA) for the first unified multimodal diffusion model: MMaDA. + + +## 🧬 MMaDA Series Overview + +MMaDA includes a series of checkpoints reflecting different training stages: +1. **[MMaDA-8B-Base](https://huggingface.co/Gen-Verse/MMaDA-8B-Base)**: After pretraining and instruction tuning. Capable of basic text generation, image generation, image captioning and **thinking ablities**. +2. **[MMaDA-8B-MixCoT](https://huggingface.co/Gen-Verse/MMaDA-8B-MixCoT)**: After mixed long chain-of-thought (CoT) fine-tuning. Capable of **complex** textual, multimodal and image generation reasoning. +3. **MMaDA-8B-Max (coming soon)**: After UniGRPO reinforment learning. Excels at complex reasoning and awesome visual generation. Will be released in the future. +
+ +

Overview of MMaDA's capablities.

+
+ + + + +## āœ… TODO +- [x] Release [MMaDA-8B-MixCoT](https://huggingface.co/Gen-Verse/MMaDA-8B-MixCoT) +- [ ] Release MMaDA-8B-Max and OpenRLHF-based UniGRPO training code. + +## āš™ļø Quick Start +First, set up the enviroment: +``` +pip install -r requirements.txt +``` +Launch local Gradio demo: +``` +python app.py +``` +Or try it online via our [Huggingface Demo](https://huggingface.co/spaces/Gen-Verse/MMaDA). + +## šŸš€ Inference +For batch-level inference, we provide our inference scripts here. +### 1. Text Generation +For text generation, we follow LLaDA's configuration and generation script. Simple run: +```bash +python generate.py +``` + +### 2. MultiModal Generation +For multimodal generation and text-to-image generation, first login your wandb account: +``` +wandb login +``` +Inference demo for MultiModal Generation and you can view the results on wandb: +``` +python3 inference_mmu.py config=configs/mmada_demo.yaml mmu_image_root=./mmu_validation question='Please describe this image in detail.' +``` + +### 3. Text-to-Image Genertion +For multimodal generation and text-to-image generation, first login your wandb account: +``` +wandb login +``` +Inference demo for Text-to-Image Genertion and you can view the results on wandb: +``` +python3 inference_t2i.py config=configs/mmada_demo.yaml batch_size=1 validation_prompts_file=validation_prompts/text2image_prompts.txt guidance_scale=3.5 generation_timesteps=15 +mode='t2i' +``` + +## šŸ”§ Training +**Update your training data path in `configs/xx.yaml`.** + +### Stage 0. Prepare your accelerate configs +Please first prepare your accelerate configs. You can simple run +``` +accelerate config +``` + +Or use our provided configs in `accelerate_configs`: +``` +ā”œā”€ā”€ accelerate_configs/ +| ā”œā”€ā”€ 1_gpu.yaml +| └── 8_node_8_gpus_deepspeed_zero2.yaml (for 8 * 8 gpus) +``` + +### Stage 1.1: Pre-training on ImageNet +First we use LLaDA-8B-Instruct to initialize our model, and train on ImageNet for basic visual capbalities. +``` +accelerate launch --config_file path/to/your/accelerate_config --main_process_port=8888 training/train_mmada.py config=configs/mmada_pretraining_stage1_llada_instruct.yaml +``` + +### Stage 1.2 Pre-training on Image-Text Dataset +Then we replace the ImageNet dataset in Stage 1.1 with Image-Text Dataset. Please change the pretrained model path in `mmada_pretraining_stage2_llada_instruct.yaml` with your checkpoint in Stage 1.1 +``` +accelerate launch --config_file path/to/your/accelerate_config --main_process_port=8888 training/train_mmada_stage2.py config=configs/mmada_pretraining_stage2_llada_instruct.yaml +``` + +### Stage 1.3 Pre-training on Text Instruction following +In this stage, we begin training on text instruction following and include corresponding validations. Please change the pretrained model path in `mmada_pretraining_stage3_llada_instruct.yaml` with your checkpoint in Stage 1.2 +``` +accelerate launch --config_file path/to/your/accelerate_config --main_process_port=8888 training/train_mmada_stage3.py config=configs/mmada_pretraining_stage3_llada_instruct.yaml +``` + +### Stage 2.1 Mix-CoT Training (Text Only) +In this stage, we begin our Mix-CoT finetuning with text reasoning first, along with improved image quality. Please change the pretrained model path in `mmada_pretraining_stage3_llada_instruct.yaml` with your checkpoint in Stage 1.3 and prepare your CoT data. +``` +accelerate launch --config_file path/to/your/accelerate_config --main_process_port=8888 training/train_mmada_stage_cot_sft.py config=configs/mmada_pretraining_stage3_llada_instruct_512_cot.yaml +``` + +### Stage 2.2 Mix-CoT Training (with MultiModal Reasoning) +In this stage, we include multimodal reasoning, along with improved image quality. Please change the pretrained model path in `mmada_pretraining_stage3_llada_instruct.yaml` with your checkpoint in Stage 2.1 and prepare your CoT data. +``` +accelerate launch --config_file path/to/your/accelerate_config --main_process_port=8888 training/train_mmada_stage4.py config=configs/mmada_pretraining_stage4_llada_instruct.yaml +``` + +### Stage 3 UniGRPO RL +[Will be released once we finished our code transition to OpenRLHF] + + +## šŸ“– Citation +``` +@article{yang2025mmada, + title={MMaDA: Multimodal Large Diffusion Language Models}, + author={Yang, Ling and Tian, Ye and Li, Bowen and Zhang, Xinchen and Shen, Ke and Tong, Yunhai and Wang, Mengdi}, + journal={arXiv preprint arXiv:2505.15809}, + year={2025} +} +``` + +## šŸ¤ Acknowledgments +This work is heavily based on [Show-o](https://github.com/showlab/Show-o), [LLaDA](https://github.com/ML-GSAI/LLaDA), [maskgit](https://github.com/google-research/maskgit), [transformers](https://github.com/huggingface/transformers), [accelerate](https://github.com/huggingface/accelerate) and [webdataset](https://github.com/webdataset/webdataset). Thanks to all the authors for their great work. + +## šŸ’¬ Discussion and Collaboration + +Welcome to discuss and collaborate with us for continuously improving MMaDA. If you have any bad cases, please kindly share them in the [Issue](https://github.com/Gen-Verse/MMaDA/issues/4#issue-3083196081). + +Also, you can reach us with this WeChat QR code! +

+ +

+ diff --git a/MMaDA/accelerate_configs/1_gpu.yaml b/MMaDA/accelerate_configs/1_gpu.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2f4525c78804745e0805a5df05fee00197163a09 --- /dev/null +++ b/MMaDA/accelerate_configs/1_gpu.yaml @@ -0,0 +1,15 @@ +compute_environment: LOCAL_MACHINE +distributed_type: 'NO' +downcast_bf16: 'no' +gpu_ids: '0' +machine_rank: 0 +main_training_function: main +mixed_precision: bf16 +num_machines: 1 +num_processes: 1 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false \ No newline at end of file diff --git a/MMaDA/accelerate_configs/1_node_8_gpus_deepspeed_zero2.yaml b/MMaDA/accelerate_configs/1_node_8_gpus_deepspeed_zero2.yaml new file mode 100644 index 0000000000000000000000000000000000000000..30b2d800f85d86f33ec7d3385723a88885bffa36 --- /dev/null +++ b/MMaDA/accelerate_configs/1_node_8_gpus_deepspeed_zero2.yaml @@ -0,0 +1,21 @@ +compute_environment: LOCAL_MACHINE +deepspeed_config: + deepspeed_multinode_launcher: standard + gradient_accumulation_steps: 1 + gradient_clipping: 1.0 + offload_optimizer_device: cpu + offload_param_device: cpu + zero3_init_flag: true + zero_stage: 2 +distributed_type: DEEPSPEED +downcast_bf16: 'no' +main_training_function: main +mixed_precision: bf16 +num_machines: 1 +num_processes: 8 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false \ No newline at end of file diff --git a/MMaDA/accelerate_configs/1_node_8_gpus_deepspeed_zero3.yaml b/MMaDA/accelerate_configs/1_node_8_gpus_deepspeed_zero3.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0e16c6bd6051caadf07df618bccabf55fdfe425b --- /dev/null +++ b/MMaDA/accelerate_configs/1_node_8_gpus_deepspeed_zero3.yaml @@ -0,0 +1,24 @@ +compute_environment: LOCAL_MACHINE +deepspeed_config: + deepspeed_multinode_launcher: standard + gradient_accumulation_steps: 2 + gradient_clipping: 1.0 + offload_optimizer_device: cpu + offload_param_device: cpu + zero3_init_flag: true + zero3_save_16bit_model: true + zero_stage: 3 + zero_optimization: + overlap_comm: false +distributed_type: DEEPSPEED +downcast_bf16: 'no' +main_training_function: main +mixed_precision: bf16 +num_machines: 1 +num_processes: 8 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false \ No newline at end of file diff --git a/MMaDA/accelerate_configs/1_node_8_gpus_deepspeed_zero4.yaml b/MMaDA/accelerate_configs/1_node_8_gpus_deepspeed_zero4.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ea5b9f219008290ac4c62ded876e7e5bb6729880 --- /dev/null +++ b/MMaDA/accelerate_configs/1_node_8_gpus_deepspeed_zero4.yaml @@ -0,0 +1,24 @@ +compute_environment: LOCAL_MACHINE +deepspeed_config: + deepspeed_multinode_launcher: standard + gradient_accumulation_steps: 1 + gradient_clipping: 1.0 + offload_optimizer_device: cpu + offload_param_device: cpu + zero3_init_flag: true + zero3_save_16bit_model: true + zero_stage: 2 + zero_optimization: + overlap_comm: false +distributed_type: DEEPSPEED +downcast_bf16: 'no' +main_training_function: main +mixed_precision: bf16 +num_machines: 1 +num_processes: 8 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false \ No newline at end of file diff --git a/MMaDA/accelerate_configs/2_node_8_gpus_deepspeed_zero2_aidas.yaml b/MMaDA/accelerate_configs/2_node_8_gpus_deepspeed_zero2_aidas.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9029f3a85ad24f5e51b9188f13af001335a42c03 --- /dev/null +++ b/MMaDA/accelerate_configs/2_node_8_gpus_deepspeed_zero2_aidas.yaml @@ -0,0 +1,25 @@ +compute_environment: LOCAL_MACHINE +deepspeed_config: + deepspeed_multinode_launcher: standard + gradient_accumulation_steps: 1 + gradient_clipping: 1.0 + offload_optimizer_device: cpu + offload_param_device: cpu + zero3_init_flag: true + zero3_save_16bit_model: true + zero_stage: 2 + zero_optimization: + overlap_comm: false +distributed_type: DEEPSPEED +downcast_bf16: 'no' +enable_cpu_affinity: false +main_process_ip: 172.51.80.134 +main_training_function: main +num_machines: 2 +num_processes: 16 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false \ No newline at end of file diff --git a/MMaDA/accelerate_configs/2_node_8_gpus_deepspeed_zero2_aidas2.yaml b/MMaDA/accelerate_configs/2_node_8_gpus_deepspeed_zero2_aidas2.yaml new file mode 100644 index 0000000000000000000000000000000000000000..57c02ff4043d86350223992514199024f3a409db --- /dev/null +++ b/MMaDA/accelerate_configs/2_node_8_gpus_deepspeed_zero2_aidas2.yaml @@ -0,0 +1,25 @@ +compute_environment: LOCAL_MACHINE +deepspeed_config: + deepspeed_multinode_launcher: standard + gradient_accumulation_steps: 1 + gradient_clipping: 1.0 + offload_optimizer_device: cpu + offload_param_device: cpu + zero3_init_flag: true + zero3_save_16bit_model: true + zero_stage: 2 + zero_optimization: + overlap_comm: false +distributed_type: DEEPSPEED +downcast_bf16: 'no' +enable_cpu_affinity: false +main_process_ip: 172.51.80.136 +main_training_function: main +num_machines: 4 +num_processes: 32 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false \ No newline at end of file diff --git a/MMaDA/accelerate_configs/2_node_8_gpus_deepspeed_zero4.yaml b/MMaDA/accelerate_configs/2_node_8_gpus_deepspeed_zero4.yaml new file mode 100644 index 0000000000000000000000000000000000000000..48df01331374b048882fca4dfe23484303c85f43 --- /dev/null +++ b/MMaDA/accelerate_configs/2_node_8_gpus_deepspeed_zero4.yaml @@ -0,0 +1,26 @@ +compute_environment: LOCAL_MACHINE +deepspeed_config: + deepspeed_multinode_launcher: standard + gradient_accumulation_steps: 4 + gradient_clipping: 1.0 + offload_optimizer_device: cpu + offload_param_device: cpu + zero3_init_flag: true + zero3_save_16bit_model: true + zero_stage: 2 + zero_optimization: + overlap_comm: false +distributed_type: DEEPSPEED +downcast_bf16: 'no' +enable_cpu_affinity: false +main_process_ip: 172.51.64.134 +main_training_function: main +num_machines: 2 +num_processes: 16 +machine_rank: 1 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false \ No newline at end of file diff --git a/MMaDA/accelerate_configs/3_node_8_gpus_deepspeed_zero1.yaml b/MMaDA/accelerate_configs/3_node_8_gpus_deepspeed_zero1.yaml new file mode 100644 index 0000000000000000000000000000000000000000..707453ded5101d525d7810574c128f661c59227c --- /dev/null +++ b/MMaDA/accelerate_configs/3_node_8_gpus_deepspeed_zero1.yaml @@ -0,0 +1,25 @@ +compute_environment: LOCAL_MACHINE +deepspeed_config: + deepspeed_multinode_launcher: standard + gradient_accumulation_steps: 4 + gradient_clipping: 1.0 + offload_optimizer_device: cpu + offload_param_device: cpu + zero3_init_flag: true + zero3_save_16bit_model: true + zero_stage: 2 + zero_optimization: + overlap_comm: false +distributed_type: DEEPSPEED +downcast_bf16: 'no' +enable_cpu_affinity: false +main_process_ip: 172.51.64.130 +main_training_function: main +num_machines: 3 +num_processes: 24 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false \ No newline at end of file diff --git a/MMaDA/accelerate_configs/4_node_8_gpus_deepspeed_zero2.yaml b/MMaDA/accelerate_configs/4_node_8_gpus_deepspeed_zero2.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0dd7a4cbbbd7902ce3abc1c2cfac1b4649d7fbd8 --- /dev/null +++ b/MMaDA/accelerate_configs/4_node_8_gpus_deepspeed_zero2.yaml @@ -0,0 +1,21 @@ +compute_environment: LOCAL_MACHINE +deepspeed_config: + deepspeed_multinode_launcher: standard + gradient_accumulation_steps: 4 + gradient_clipping: 1.0 + offload_optimizer_device: cpu + offload_param_device: cpu + zero3_init_flag: true + zero_stage: 2 +distributed_type: DEEPSPEED +downcast_bf16: 'no' +main_training_function: main +mixed_precision: bf16 +num_machines: 4 +num_processes: 32 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false \ No newline at end of file diff --git a/MMaDA/accelerate_configs/4_node_8_gpus_deepspeed_zero2_aidas.yaml b/MMaDA/accelerate_configs/4_node_8_gpus_deepspeed_zero2_aidas.yaml new file mode 100644 index 0000000000000000000000000000000000000000..13b4a8d0f00d68de8373deb98974287431292e7d --- /dev/null +++ b/MMaDA/accelerate_configs/4_node_8_gpus_deepspeed_zero2_aidas.yaml @@ -0,0 +1,25 @@ +compute_environment: LOCAL_MACHINE +deepspeed_config: + deepspeed_multinode_launcher: standard + gradient_accumulation_steps: 1 + gradient_clipping: 1.0 + offload_optimizer_device: none #cpu + offload_param_device: none #cpu + zero3_init_flag: true + zero3_save_16bit_model: true + zero_stage: 2 + zero_optimization: + overlap_comm: false +distributed_type: DEEPSPEED +downcast_bf16: 'no' +enable_cpu_affinity: true +main_process_ip: 172.51.133.6 +main_training_function: main +num_machines: 4 +num_processes: 32 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false \ No newline at end of file diff --git a/MMaDA/accelerate_configs/8_node_8_gpus_deepspeed_zero2.yaml b/MMaDA/accelerate_configs/8_node_8_gpus_deepspeed_zero2.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a9df6e71c045830d81d6174c9a5ae3efbb3df126 --- /dev/null +++ b/MMaDA/accelerate_configs/8_node_8_gpus_deepspeed_zero2.yaml @@ -0,0 +1,21 @@ +compute_environment: LOCAL_MACHINE +deepspeed_config: + deepspeed_multinode_launcher: standard + gradient_accumulation_steps: 1 + gradient_clipping: 1.0 + offload_optimizer_device: cpu + offload_param_device: cpu + zero3_init_flag: true + zero_stage: 2 +distributed_type: DEEPSPEED +downcast_bf16: 'no' +main_training_function: main +mixed_precision: bf16 +num_machines: 8 +num_processes: 64 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false \ No newline at end of file diff --git a/MMaDA/app.py b/MMaDA/app.py new file mode 100644 index 0000000000000000000000000000000000000000..1dbb3aa3c14863e1292379a4ea34ad93e2b92cb6 --- /dev/null +++ b/MMaDA/app.py @@ -0,0 +1,894 @@ +import gradio as gr +import torch +import numpy as np +import torch.nn.functional as F +from transformers import AutoTokenizer +from torchvision import transforms +from models import MAGVITv2, get_mask_schedule, MMadaModelLM +from training.prompting_utils import UniversalPrompting +from PIL import Image + +def image_transform(image, resolution=256, normalize=True): + image = transforms.Resize(resolution, interpolation=transforms.InterpolationMode.BICUBIC)(image) + image = transforms.CenterCrop((resolution, resolution))(image) + image = transforms.ToTensor()(image) + if normalize: + image = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)(image) + return image + +def add_gumbel_noise(logits, temperature): + """ + Adds Gumbel noise to logits for stochastic sampling. + Equivalent to argmax(logits + temperature * G) where G ~ Gumbel(0,1). + This version is more numerically stable than a version involving exp() and division. + """ + if abs(temperature) < 1e-9: # Effectively zero temperature + return logits + # Ensure logits are float64 for precision with noise, as suggested by user context + if DEVICE == "mps": + logits = logits.to(torch.float32) + else: + logits = logits.to(torch.float64) + # Standard Gumbel noise: -log(-log(U)), U ~ Uniform(0,1) + # Add small epsilon for numerical stability inside logs + if DEVICE == "mps": + noise = torch.rand_like(logits, dtype=torch.float32) + else: + noise = torch.rand_like(logits, dtype=torch.float64) + standard_gumbel_noise = -torch.log(-torch.log(noise + 1e-20) + 1e-20) + return logits + temperature * standard_gumbel_noise + +def get_num_transfer_tokens(mask_index, steps): + mask_num = mask_index.sum(dim=1, keepdim=True) + # Ensure steps is at least 1 to avoid division by zero if mask_num is also 0 (though sum should be >=0) + steps = max(1, int(steps)) # Ensure steps is a positive integer + base = mask_num // steps + remainder = mask_num % steps + num_transfer_tokens = torch.zeros(mask_num.size(0), steps, device=mask_index.device, dtype=torch.long) + base + for i in range(mask_num.size(0)): # Iterate over batch + if remainder[i] > 0 : # Ensure remainder is positive before indexing + num_transfer_tokens[i, :remainder[i].item()] += 1 # .item() for single value tensor to int + return num_transfer_tokens + +MODEL = None +TOKENIZER = None +DEVICE = ( + "cuda" + if torch.cuda.is_available() + else "mps" if torch.backends.mps.is_available() else "cpu" +) +MASK_ID = None +uni_prompting = None +VQ_MODEL = MAGVITv2().from_pretrained("showlab/magvitv2").to(DEVICE) + +DEFAULT_MODEL_PATH = "Gen-Verse/MMaDA-8B-Base" # Default +CURRENT_MODEL_PATH = None + +MODEL_CHOICES = [ + "MMaDA-8B-Base", + "MMaDA-8B-MixCoT (coming soon)", + "MMaDA-8B-Max (coming soon)" +] +MODEL_ACTUAL_PATHS = { + "MMaDA-8B-Base": DEFAULT_MODEL_PATH, +} + +def clear_outputs_action(): + return None, None + +def _load_model_and_tokenizer_core(model_path_to_load, model_display_name_for_status): + global MODEL, TOKENIZER, MASK_ID, CURRENT_MODEL_PATH, DEVICE, uni_prompting + + if MODEL is not None and CURRENT_MODEL_PATH == model_path_to_load: + return f"Model '{model_display_name_for_status}' from '{model_path_to_load}' is already loaded. MASK_ID: {MASK_ID}" + + CURRENT_MODEL_PATH = model_path_to_load + + status_msg_parts = [f"Loading '{model_display_name_for_status}'..."] + try: + TOKENIZER = AutoTokenizer.from_pretrained(model_path_to_load, trust_remote_code=True) + status_msg_parts.append(f"Tokenizer for '{model_display_name_for_status}' loaded.") + + MODEL = MMadaModelLM.from_pretrained(model_path_to_load, trust_remote_code=True, torch_dtype=torch.bfloat16).to(DEVICE).eval() + status_msg_parts.append(f"Model '{model_display_name_for_status}' loaded to {DEVICE}.") + + uni_prompting = UniversalPrompting(TOKENIZER, max_text_len=512, special_tokens=("<|soi|>", "<|eoi|>", "<|sov|>", "<|eov|>", "<|t2i|>", "<|mmu|>", "<|t2v|>", "<|v2v|>", "<|lvg|>"),ignore_id=-100, cond_dropout_prob=0.1, use_reserved_token=True) + + if hasattr(TOKENIZER, 'mask_token_id') and TOKENIZER.mask_token_id is not None: + MASK_ID = TOKENIZER.mask_token_id + status_msg_parts.append(f"Using MASK_ID from tokenizer: {MASK_ID}.") + else: + MASK_ID = 126336 + status_msg_parts.append(f"Using default MASK_ID: {MASK_ID}.") + + if TOKENIZER.pad_token_id is None: + if TOKENIZER.eos_token_id is not None: + TOKENIZER.pad_token_id = TOKENIZER.eos_token_id + TOKENIZER.pad_token = TOKENIZER.eos_token + status_msg_parts.append(f"Set pad_token_id to eos_token_id ({TOKENIZER.eos_token_id}).") + else: + status_msg_parts.append("Warning: pad_token_id is None and no eos_token_id.") + + if TOKENIZER.eos_token_id is None: # Important for cleaning up output in visualization + status_msg_parts.append("Warning: tokenizer.eos_token_id is None. EOS cleanup might not work.") + + TOKENIZER.chat_template = "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{{ '<|start_header_id|>assistant<|end_header_id|>\n' }}" + + return " ".join(status_msg_parts) + except Exception as e: + MODEL = None + TOKENIZER = None + MASK_ID = None + CURRENT_MODEL_PATH = None + return f"Error loading model '{model_display_name_for_status}': {str(e)}" + +def handle_model_selection_change(selected_model_name_ui): + if "coming soon" in selected_model_name_ui.lower(): + global MODEL, TOKENIZER, MASK_ID, CURRENT_MODEL_PATH + MODEL = None + TOKENIZER = None + MASK_ID = None + CURRENT_MODEL_PATH = None + return f"'{selected_model_name_ui}' is not yet available. Please select 'Model A'." + + actual_path = MODEL_ACTUAL_PATHS.get(selected_model_name_ui) + if not actual_path: + return f"Path for '{selected_model_name_ui}' is not defined. Cannot load." + + return _load_model_and_tokenizer_core(actual_path, selected_model_name_ui) + + +def get_highlighted_text_tuples(current_x_ids_batch, prompt_input_ids, prompt_len, tk, current_mask_id, raw_prompt_attention_mask): + if current_x_ids_batch is None or current_x_ids_batch.ndim == 0 or current_x_ids_batch.shape[0] == 0: + return [("Error in sequence data for visualization.", "ERROR")] + # only answer part + current_x_ids_batch = current_x_ids_batch[:, prompt_len:] + seq_ids = current_x_ids_batch[0].tolist() + eos_token_id = tk.eos_token_id # Get EOS token ID + + # Stage 1: Build initial list of tuples with (token_str, label, token_id_int) + # This helps in identifying EOS tokens later without re-checking the type. + intermediate_tuples = [] + for j, token_id_int in enumerate(seq_ids): + try: + token_str = tk.decode([token_id_int], skip_special_tokens=True, clean_up_tokenization_spaces=False) + except Exception: # Handle cases where a token ID might be problematic (e.g. with mock) + token_str = f"[ID:{token_id_int}]" + + label = "ERROR" + if token_id_int == current_mask_id: + token_str = "[MASK]" + label = "MASK" + else: + label = "GEN" + intermediate_tuples.append((token_str, label, token_id_int)) + + return intermediate_tuples + +@torch.no_grad() +def generate_viz_wrapper_t2i(prompt_text, steps, guidance_scale, mask_schedule="cosine"): + global MODEL, TOKENIZER, MASK_ID, DEVICE, uni_prompting + + if MODEL is None or TOKENIZER is None or MASK_ID is None: + yield [("Error: Model not loaded. Please load the model first.", "ERROR")], "Model not loaded." + return + steps = int(steps) + guidance_scale = float(guidance_scale) + + image_tokens = torch.ones((1, 1024), dtype=torch.long, device=DEVICE) * MASK_ID + prompt_text = [prompt_text] + input_ids, attention_mask = uni_prompting((prompt_text, image_tokens), 't2i_gen') + + if guidance_scale > 0: + uncond_input_ids, uncond_attention_mask = uni_prompting(([''], image_tokens), 't2i_gen') + else: + uncond_input_ids, uncond_attention_mask = None, None + + mask_schedule = get_mask_schedule(mask_schedule) + blank_image = Image.new("RGB", (512, 512), (255, 255, 255)) + yield blank_image, "Starting generation..." + for image_step, status_msg_step in MODEL.t2i_generate_decoding_stepwise( + input_ids = input_ids, + uncond_input_ids = uncond_input_ids, + attention_mask = attention_mask, + uncond_attention_mask = uncond_attention_mask, + temperature=1.0, + timesteps = steps, + guidance_scale = guidance_scale, + noise_schedule = mask_schedule, + noise_type = "mask", + seq_len = 1024, + vq_model = VQ_MODEL, + uni_prompting=uni_prompting): + yield image_step, status_msg_step + + + + +@torch.no_grad() +def generate_viz_wrapper_lm(prompt_text, steps, gen_length, block_length, temperature, + cfg_scale, remasking_strategy, thinking_mode_lm): + global MODEL, TOKENIZER, MASK_ID, DEVICE + print(f"thinking_mode_lm: {thinking_mode_lm}") + if MODEL is None or TOKENIZER is None or MASK_ID is None: + yield [("Error: Model not loaded. Please load the model first.", "ERROR")], "Model not loaded." + return + + steps = int(steps) + gen_length = int(gen_length) + block_length = int(block_length) + + if thinking_mode_lm: + prompt_text = "You should first think about the reasoning process in the mind and then provide the user with the answer. The reasoning process is enclosed within tags, i.e. reasoning process here answer here\n" + prompt_text + + try: + m = [{"role": "user", "content": prompt_text}] + processed_prompt_text = TOKENIZER.apply_chat_template(m, add_generation_prompt=True, tokenize=False) + except Exception as e: + yield [("Error applying chat template.", "ERROR")], f"Chat template error: {e}" + processed_prompt_text = prompt_text + try: + if TOKENIZER.pad_token_id is None: + if TOKENIZER.eos_token_id is not None: + TOKENIZER.pad_token_id = TOKENIZER.eos_token_id + else: # Should have been caught by load_model, but double check + yield [("Tokenizer Error", "ERROR")], "pad_token_id is not set in tokenizer." + return + + input_ids = TOKENIZER(text=processed_prompt_text, return_tensors="pt", padding="longest", padding_side="left", truncation=True, max_length=MODEL.config.max_position_embeddings if hasattr(MODEL.config, 'max_position_embeddings') else 2048)['input_ids'].to(DEVICE) + raw_prompt_attention_mask = None + + except Exception as e: + yield [("Error tokenizing prompt.", "ERROR")], f"Tokenization error: {e}" + return + + + + batch_size = input_ids.shape[0] + prompt_len = input_ids.shape[1] + + x = torch.full((batch_size, prompt_len + gen_length), MASK_ID, dtype=torch.long, device=DEVICE) + x[:, :prompt_len] = input_ids.clone() + + yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), "Starting generation: Prompt + Initial Masks" + + if gen_length == 0: + final_text_output = TOKENIZER.batch_decode(x[:,prompt_len:], skip_special_tokens=True) + yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), final_text_output[0] if final_text_output else "" + return + + if block_length <= 0 or gen_length % block_length != 0 : + yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), \ + f"Error: gen_length ({gen_length}) must be divisible by block_length ({block_length}) and block_length > 0." + return + num_blocks = gen_length // block_length + + if steps <=0 or steps % num_blocks != 0: + yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), \ + f"Error: steps ({steps}) must be positive and divisible by num_blocks ({num_blocks}). Steps: {steps}, Num Blocks: {num_blocks}" + return + steps_per_block = steps // num_blocks + + for num_block_iter in range(num_blocks): + current_block_start_idx_in_x = prompt_len + num_block_iter * block_length + current_block_end_idx_in_x = prompt_len + (num_block_iter + 1) * block_length + + block_masks_bool_current = torch.zeros_like(x, dtype=torch.bool) + block_masks_bool_current[:, current_block_start_idx_in_x:current_block_end_idx_in_x] = \ + (x[:, current_block_start_idx_in_x:current_block_end_idx_in_x] == MASK_ID) + + num_transfer_tokens_for_this_block = get_num_transfer_tokens( + block_masks_bool_current[:, current_block_start_idx_in_x:current_block_end_idx_in_x], + steps_per_block + ) + + for i_step_in_block in range(steps_per_block): + mask_index_global = (x == MASK_ID) + + if cfg_scale > 0.: + un_x = x.clone() + # For unconditional pass, mask out the original prompt tokens that are not padding + # raw_prompt_attention_mask is (B, prompt_len) + prompt_active_tokens_mask = raw_prompt_attention_mask.bool() # True where actual prompt tokens are + un_x[:, :prompt_len][prompt_active_tokens_mask] = MASK_ID + + x_cfg_input = torch.cat([x, un_x], dim=0) + # Pass attention_mask for CFG if model expects it, covering both parts + # For simplicity, not passing explicit attention_mask here; relies on model's internal handling. + model_output = MODEL(x_cfg_input) + logits_cond, logits_uncond = torch.chunk(model_output.logits, 2, dim=0) + logits = logits_uncond + (cfg_scale + 1) * (logits_cond - logits_uncond) + else: + # Not passing explicit attention_mask here; relies on model's internal handling. + model_output = MODEL(x) + logits = model_output.logits + + logits_with_noise = add_gumbel_noise(logits, temperature=temperature) + x0_predicted_tokens = torch.argmax(logits_with_noise, dim=-1) + + if remasking_strategy == 'low_confidence': + if DEVICE == "mps": + probs = F.softmax(logits.to(torch.float32), dim=-1) + else: + probs = F.softmax(logits.to(torch.float64), dim=-1) + x0_probs = torch.gather(probs, dim=-1, index=x0_predicted_tokens.unsqueeze(-1)).squeeze(-1) + elif remasking_strategy == 'random': + if DEVICE == "mps": + x0_probs = torch.rand(x.shape, device=x.device, dtype=torch.float32) + else: + x0_probs = torch.rand(x.shape, device=x.device, dtype=torch.float64) + else: + yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), f"Error: Unknown remasking strategy '{remasking_strategy}'" + return + + confidence_for_selection = torch.full_like(x0_probs, -torch.inf) + candidate_positions_for_unmasking = mask_index_global & block_masks_bool_current + confidence_for_selection = torch.where( + candidate_positions_for_unmasking, + x0_probs, + -torch.inf + ) + + x0_final_candidates = torch.where(mask_index_global, x0_predicted_tokens, x) + + transfer_indices_bool = torch.zeros_like(x, dtype=torch.bool) + num_to_transfer_this_step_batch = num_transfer_tokens_for_this_block[:, i_step_in_block] + + for j_batch_idx in range(batch_size): + k_val = min(num_to_transfer_this_step_batch[j_batch_idx].item(), + candidate_positions_for_unmasking[j_batch_idx].sum().item()) # ensure k isn't too large + + if k_val > 0: + # Ensure confidence_for_selection[j_batch_idx] is 1D for topk + conf_slice = confidence_for_selection[j_batch_idx] + if conf_slice.ndim > 1: conf_slice = conf_slice.view(-1) # Should already be 1D from x0_probs + + # Check if there are enough valid (non -inf) confidences + valid_conf_count = (conf_slice > -torch.inf).sum().item() + actual_k = min(k_val, valid_conf_count) + + if actual_k > 0: + _, topk_indices_in_x = torch.topk(conf_slice, k=actual_k) + transfer_indices_bool[j_batch_idx, topk_indices_in_x] = True + + x[transfer_indices_bool] = x0_final_candidates[transfer_indices_bool] + + current_total_step = num_block_iter * steps_per_block + i_step_in_block + 1 + total_overall_steps = num_blocks * steps_per_block + status_msg = f"Block {num_block_iter+1}/{num_blocks}, Step {i_step_in_block+1}/{steps_per_block} (Total: {current_total_step}/{total_overall_steps})" + yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), status_msg + + final_generated_ids = x[:, prompt_len:] + final_text_output = TOKENIZER.batch_decode(final_generated_ids, skip_special_tokens=True) + + final_text_str = final_text_output[0] if final_text_output and len(final_text_output) > 0 else "" + yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), final_text_str + +@torch.no_grad() +def generate_viz_wrapper(uploaded_image_pil, prompt_text, steps, gen_length, block_length, temperature, + cfg_scale, remasking_strategy, thinking_mode_mmu): + global MODEL, TOKENIZER, MASK_ID, DEVICE + + if MODEL is None or TOKENIZER is None or MASK_ID is None: + yield [("Error: Model not loaded. Please load the model first.", "ERROR")], "Model not loaded." + return + + steps = int(steps) + gen_length = int(gen_length) + block_length = int(block_length) + + if thinking_mode_mmu: + prompt_text = "You should first think about the reasoning process in the mind and then provide the user with the answer. The reasoning process is enclosed within tags, i.e. reasoning process here answer here\n" + prompt_text + + try: + m = [{"role": "user", "content": prompt_text}] + processed_prompt_text = TOKENIZER.apply_chat_template(m, add_generation_prompt=True, tokenize=False) + except Exception as e: + yield [("Error applying chat template.", "ERROR")], f"Chat template error: {e}" + processed_prompt_text = prompt_text + + image_vq_ids_tensor = None + if uploaded_image_pil is not None: + try: + + image = image_transform(uploaded_image_pil, resolution=512).to(DEVICE) + image = image.unsqueeze(0) + image_vq_ids_tensor = VQ_MODEL.get_code(image) + 126349 + except Exception as e: + yield [("Error processing image.", "ERROR")], f"Image to VQ tokens conversion failed: {str(e)}" + return + + + try: + if TOKENIZER.pad_token_id is None: + if TOKENIZER.eos_token_id is not None: + TOKENIZER.pad_token_id = TOKENIZER.eos_token_id + else: + yield [("Tokenizer Error", "ERROR")], "pad_token_id is not set in tokenizer." + return + + input_ids = TOKENIZER(text=processed_prompt_text, return_tensors="pt", padding="longest", padding_side="left", truncation=True, max_length=MODEL.config.max_position_embeddings if hasattr(MODEL.config, 'max_position_embeddings') else 2048)['input_ids'].to(DEVICE) + raw_prompt_attention_mask = None + if image_vq_ids_tensor is not None: + if image_vq_ids_tensor.ndim == 1: + image_vq_ids_tensor = image_vq_ids_tensor.unsqueeze(0) + + input_ids = torch.cat([ + (torch.ones(input_ids.shape[0], 1) * torch.tensor([126089])).to(DEVICE), + (torch.ones(input_ids.shape[0], 1) * torch.tensor([126084])).to(DEVICE), + image_vq_ids_tensor, + (torch.ones(input_ids.shape[0], 1) * torch.tensor([126085])).to(DEVICE), + input_ids + ], dim=1).long() + + else: + input_ids = input_ids + + + except Exception as e: + yield [("Error tokenizing prompt.", "ERROR")], f"Tokenization error: {e}" + return + + + + batch_size = input_ids.shape[0] + prompt_len = input_ids.shape[1] + + x = torch.full((batch_size, prompt_len + gen_length), MASK_ID, dtype=torch.long, device=DEVICE) + x[:, :prompt_len] = input_ids.clone() + + yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), "Starting generation: Prompt + Initial Masks" + + if gen_length == 0: + final_text_output = TOKENIZER.batch_decode(x[:,prompt_len:], skip_special_tokens=True) + yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), final_text_output[0] if final_text_output else "" + return + + if block_length <= 0 or gen_length % block_length != 0 : + yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), \ + f"Error: gen_length ({gen_length}) must be divisible by block_length ({block_length}) and block_length > 0." + return + num_blocks = gen_length // block_length + + if steps <=0 or steps % num_blocks != 0: + yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), \ + f"Error: steps ({steps}) must be positive and divisible by num_blocks ({num_blocks}). Steps: {steps}, Num Blocks: {num_blocks}" + return + steps_per_block = steps // num_blocks + + for num_block_iter in range(num_blocks): + current_block_start_idx_in_x = prompt_len + num_block_iter * block_length + current_block_end_idx_in_x = prompt_len + (num_block_iter + 1) * block_length + + block_masks_bool_current = torch.zeros_like(x, dtype=torch.bool) + block_masks_bool_current[:, current_block_start_idx_in_x:current_block_end_idx_in_x] = \ + (x[:, current_block_start_idx_in_x:current_block_end_idx_in_x] == MASK_ID) + + num_transfer_tokens_for_this_block = get_num_transfer_tokens( + block_masks_bool_current[:, current_block_start_idx_in_x:current_block_end_idx_in_x], + steps_per_block + ) + + for i_step_in_block in range(steps_per_block): + mask_index_global = (x == MASK_ID) + + if cfg_scale > 0.: + un_x = x.clone() + # For unconditional pass, mask out the original prompt tokens that are not padding + # raw_prompt_attention_mask is (B, prompt_len) + prompt_active_tokens_mask = raw_prompt_attention_mask.bool() # True where actual prompt tokens are + un_x[:, :prompt_len][prompt_active_tokens_mask] = MASK_ID + + x_cfg_input = torch.cat([x, un_x], dim=0) + # Pass attention_mask for CFG if model expects it, covering both parts + # For simplicity, not passing explicit attention_mask here; relies on model's internal handling. + model_output = MODEL(x_cfg_input) + logits_cond, logits_uncond = torch.chunk(model_output.logits, 2, dim=0) + logits = logits_uncond + (cfg_scale + 1) * (logits_cond - logits_uncond) + else: + # Not passing explicit attention_mask here; relies on model's internal handling. + model_output = MODEL(x) + logits = model_output.logits + + logits_with_noise = add_gumbel_noise(logits, temperature=temperature) + x0_predicted_tokens = torch.argmax(logits_with_noise, dim=-1) + + if remasking_strategy == 'low_confidence': + if DEVICE == "mps": + probs = F.softmax(logits.to(torch.float32), dim=-1) + else: + probs = F.softmax(logits.to(torch.float64), dim=-1) + x0_probs = torch.gather(probs, dim=-1, index=x0_predicted_tokens.unsqueeze(-1)).squeeze(-1) + elif remasking_strategy == 'random': + if DEVICE == "mps": + x0_probs = torch.rand(x.shape, device=x.device, dtype=torch.float32) + else: + x0_probs = torch.rand(x.shape, device=x.device, dtype=torch.float64) + else: + yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), f"Error: Unknown remasking strategy '{remasking_strategy}'" + return + + confidence_for_selection = torch.full_like(x0_probs, -torch.inf) + candidate_positions_for_unmasking = mask_index_global & block_masks_bool_current + confidence_for_selection = torch.where( + candidate_positions_for_unmasking, + x0_probs, + -torch.inf + ) + + x0_final_candidates = torch.where(mask_index_global, x0_predicted_tokens, x) + + transfer_indices_bool = torch.zeros_like(x, dtype=torch.bool) + num_to_transfer_this_step_batch = num_transfer_tokens_for_this_block[:, i_step_in_block] + + for j_batch_idx in range(batch_size): + k_val = min(num_to_transfer_this_step_batch[j_batch_idx].item(), + candidate_positions_for_unmasking[j_batch_idx].sum().item()) # ensure k isn't too large + + if k_val > 0: + # Ensure confidence_for_selection[j_batch_idx] is 1D for topk + conf_slice = confidence_for_selection[j_batch_idx] + if conf_slice.ndim > 1: conf_slice = conf_slice.view(-1) # Should already be 1D from x0_probs + + # Check if there are enough valid (non -inf) confidences + valid_conf_count = (conf_slice > -torch.inf).sum().item() + actual_k = min(k_val, valid_conf_count) + + if actual_k > 0: + _, topk_indices_in_x = torch.topk(conf_slice, k=actual_k) + transfer_indices_bool[j_batch_idx, topk_indices_in_x] = True + + x[transfer_indices_bool] = x0_final_candidates[transfer_indices_bool] + + current_total_step = num_block_iter * steps_per_block + i_step_in_block + 1 + total_overall_steps = num_blocks * steps_per_block + status_msg = f"Block {num_block_iter+1}/{num_blocks}, Step {i_step_in_block+1}/{steps_per_block} (Total: {current_total_step}/{total_overall_steps})" + yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), status_msg + + final_generated_ids = x[:, prompt_len:] + final_text_output = TOKENIZER.batch_decode(final_generated_ids, skip_special_tokens=True) + + final_text_str = final_text_output[0] if final_text_output and len(final_text_output) > 0 else "" + yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), final_text_str + + +css_styles = """ +.gradio-container{font-family:'IBM Plex Sans',sans-serif;margin:auto;} +.gr-input {background:#f9f9f9 !important;border:1px solid #e0e0e0 !important;} +.gr-output{background:#f0f0f0 !important;border:1px solid #d0d0d0 !important;} + +.highlighted-text span{ + padding:2px 4px;border-radius:4px;margin:1px 2px;display:inline-block;line-height:1.6; +} + +footer{display:none !important} + +#live-update-scrollable-box { + max-height: 800px; /* ę‚ØåÆä»„ę ¹ę®éœ€č¦č°ƒę•“čæ™äøŖęœ€å¤§é«˜åŗ¦ļ¼Œä¾‹å¦‚ '300px', '50vh' ē­‰ */ + overflow-y: auto !important; /* 当内容超出 max-height ę—¶ę˜¾ē¤ŗåž‚ē›“ę»šåŠØę” */ + display: block; /* ē”®äæå…ƒē“ ę˜Æå—ēŗ§å…ƒē“ ļ¼Œä»„ä¾æ max-height ē”Ÿę•ˆ */ + +} +#think_btn { + background-color: #f3f4f6 !important; + border: 1px solid #d0d0d0 !important; + color: #111827 !important; + font-size: 16px !important; + font-weight: bold !important; +} +#think_btn:hover { + background-color: #e0e0e0 !important; + border: 1px solid #c0c0c0 !important; + color: #222 !important; +} +#think_btn:active { + background-color: #2563eb !important; + border: 1px solid #b0b0b0 !important; + color: white !important; +} +""" + + +# thinking_mode_t2i = gr.State(False) +def toggle_thinking_mode_lm(current_thinking_mode): + # print(f"current_thinking_mode: {current_thinking_mode}") + new_state = not current_thinking_mode + new_label = "Thinking Mode āœ…" if new_state else "Thinking Mode āŒ" + return new_state, gr.update(value=new_label) + +def toggle_thinking_mode_mmu(current_thinking_mode): + new_state = not current_thinking_mode + new_label = "Thinking Mode āœ…" if new_state else "Thinking Mode āŒ" + return new_state, gr.update(value=new_label) + + +color_map_config = { + "MASK": "lightgrey", + "GEN": "#DCABFA", +} + +theme = gr.themes.Ocean( + primary_hue="fuchsia", +) +with gr.Blocks(css=css_styles, theme=theme) as demo: +# with gr.Blocks(css=css_styles, theme=gr.themes.Soft(primary_hue=gr.themes.colors.blue, secondary_hue=gr.themes.colors.sky)) as demo: +# with gr.Blocks() as demo: + thinking_mode_lm = gr.State(False) + thinking_mode_mmu = gr.State(False) + gr.Markdown("

MMaDA: Multimodal Large Diffusion Language Models

") + gr.Markdown("MMaDA is a novel class of multimodal diffusion foundation models designed to achieve superior performance across diverse domains such as textual reasoning, multimodal understanding, and text-to-image generation") + gr.Markdown("Github: [Gen-Verse/MMaDA](https://github.com/Gen-Verse/MMaDA)") + gr.Markdown("Paper: [MMaDA: Multimodal Large Diffusion Language Models]()") + gr.Markdown("### Select Model") + with gr.Row(): + model_select_radio = gr.Radio( + label="Select Text Generation Model", + choices=MODEL_CHOICES, + value=MODEL_CHOICES[0] + ) + model_load_status_box = gr.Textbox( + label="Model Load Status", + interactive=False, + lines=3, + max_lines=5 + ) + + gr.Markdown("## Part 1. Text Generation") + with gr.Row(): + with gr.Column(scale=2): + prompt_input_box_lm = gr.Textbox(label="Enter your prompt:", lines=3, value="A rectangular prism has a length of 5 units, a width of 4 units, and a height of 3 units. What is the volume of the prism?") + think_button_lm = gr.Button("🧠 Enable Thinking Mode", elem_id="think_btn") + with gr.Accordion("Generation Parameters", open=True): + with gr.Row(): + gen_length_slider_lm = gr.Slider(minimum=8, maximum=1024, value=512, step=64, label="Generation Length", info="Number of tokens to generate.") + steps_slider_lm = gr.Slider(minimum=1, maximum=512, value=256, step=32, label="Total Sampling Steps", info="Must be divisible by (gen_length / block_length).") + with gr.Row(): + block_length_slider_lm = gr.Slider(minimum=8, maximum=1024, value=128, step=32, label="Block Length", info="gen_length must be divisible by this.") + remasking_dropdown_lm = gr.Dropdown(choices=['low_confidence', 'random'], value='low_confidence', label="Remasking Strategy") + with gr.Row(): + cfg_scale_slider_lm = gr.Slider(minimum=0.0, maximum=2.0, value=0.0, step=0.1, label="CFG Scale", info="Classifier-Free Guidance. 0 disables it.") + temperature_slider_lm = gr.Slider(minimum=0.0, maximum=2.0, value=1, step=0.05, label="Temperature", info="Controls randomness via Gumbel noise. 0 is deterministic.") + + + with gr.Row(): + run_button_ui_lm = gr.Button("Generate Sequence", variant="primary", scale=3) + clear_button_ui_lm = gr.Button("Clear Outputs", scale=1) + + with gr.Column(scale=3): + # gr.Markdown("## Live Generation Process") + output_visualization_box_lm = gr.HighlightedText( + label="Live Generation Process", + show_legend=True, + color_map=color_map_config, + combine_adjacent=False, + interactive=False, + elem_id="live-update-scrollable-box", + ) + # gr.Markdown("## Final Generated Text") + output_final_text_box_lm = gr.Textbox(label="Final Output", lines=8, interactive=False, show_copy_button=True) + + + + gr.Examples( + examples=[ + ["A rectangular prism has a length of 5 units, a width of 4 units, and a height of 3 units. What is the volume of the prism?", 256, 512, 128, 1, 0, "low_confidence"], + ["Lily can run 12 kilometers per hour for 4 hours. After that, she can run 6 kilometers per hour. How many kilometers can she run in 8 hours?", 256, 512, 64, 1, 0, "low_confidence"] + ], + inputs=[prompt_input_box_lm, steps_slider_lm, gen_length_slider_lm, block_length_slider_lm, temperature_slider_lm, cfg_scale_slider_lm, remasking_dropdown_lm], + outputs=[output_visualization_box_lm, output_final_text_box_lm], + fn=generate_viz_wrapper_lm, + ) + + gr.Markdown("---") + gr.Markdown("## Part 2. Multimodal Understanding") + with gr.Row(): + with gr.Column(scale=2): + prompt_input_box_mmu = gr.Textbox( + label="Enter your prompt:", + lines=3, + value="Please describe this image in detail." + ) + think_button_mmu = gr.Button("🧠 Enable Thinking Mode", elem_id="think_btn") + with gr.Accordion("Generation Parameters", open=True): + with gr.Row(): + gen_length_slider_mmu = gr.Slider(minimum=64, maximum=1024, value=512, step=64, label="Generation Length", info="Number of tokens to generate.") + steps_slider_mmu = gr.Slider(minimum=1, maximum=512, value=256, step=32, label="Total Sampling Steps", info="Must be divisible by (gen_length / block_length).") + with gr.Row(): + block_length_slider_mmu = gr.Slider(minimum=32, maximum=1024, value=128, step=32, label="Block Length", info="gen_length must be divisible by this.") + remasking_dropdown_mmu = gr.Dropdown(choices=['low_confidence', 'random'], value='low_confidence', label="Remasking Strategy") + with gr.Row(): + cfg_scale_slider_mmu = gr.Slider(minimum=0.0, maximum=2.0, value=0.0, step=0.1, label="CFG Scale", info="Classifier-Free Guidance. 0 disables it.") + temperature_slider_mmu = gr.Slider(minimum=0.0, maximum=2.0, value=1, step=0.05, label="Temperature", info="Controls randomness via Gumbel noise. 0 is deterministic.") + + with gr.Row(): + image_upload_box = gr.Image(type="pil", label="Upload Image") + + with gr.Row(): + run_button_ui_mmu = gr.Button("Generate Description", variant="primary", scale=3) + clear_button_ui_mmu = gr.Button("Clear Outputs", scale=1) + + with gr.Column(scale=3): + gr.Markdown("## Live Generation Process") + output_visualization_box_mmu = gr.HighlightedText( + label="Token Sequence (Live Update)", + show_legend=True, + color_map=color_map_config, + combine_adjacent=False, + interactive=False, + elem_id="live-update-scrollable-box", + ) + gr.Markdown("## Final Generated Text") + output_final_text_box_mmu = gr.Textbox(label="Final Output", lines=8, interactive=False, show_copy_button=True) + + + gr.Examples( + examples=[ + [ + "mmu_validation_2/sunflower.jpg", + "Please describe this image in detail.", + 256, + 512, + 128, + 1, + 0, + "low_confidence" + ], + [ + "mmu_validation_2/woman.jpg", + "Please describe this image in detail.", + 256, + 512, + 128, + 1, + 0, + "low_confidence" + ] + ], + inputs=[ + image_upload_box, + prompt_input_box_mmu, + steps_slider_mmu, + gen_length_slider_mmu, + block_length_slider_mmu, + temperature_slider_mmu, + cfg_scale_slider_mmu, + remasking_dropdown_mmu + ], + outputs=[output_visualization_box_mmu, output_final_text_box_mmu], + fn=generate_viz_wrapper, + ) + + gr.Markdown("---") + gr.Markdown("## Part 3. Text-to-Image Generation") + with gr.Row(): + with gr.Column(scale=2): + prompt_input_box_t2i = gr.Textbox(label="Enter your prompt:", lines=3, value="A sea turtle swimming near a coral reef in the ocean, with a clear blue sky and water in the background.") + + with gr.Accordion("Generation Parameters", open=True): + with gr.Row(): + steps_slider_t2i = gr.Slider(minimum=5, maximum=100, value=15, step=5, label="Total Sampling Steps", info="Must be divisible by (gen_length / block_length).") + guidance_scale_slider_t2i = gr.Slider(minimum=0.0, maximum=7.0, value=3.5, step=0.5, label="Guidance Scale", info="Classifier-Free Guidance. 0 disables it.") + + + with gr.Row(): + scheduler_radio_t2i = gr.Radio( + choices=["cosine", "sigmoid", "linear"], + value="cosine", + label="Scheduler", + ) + + with gr.Row(): + run_button_ui_t2i = gr.Button("Generate Image", variant="primary", scale=3) + clear_button_ui_t2i = gr.Button("Clear Outputs", scale=1) + + + with gr.Column(scale=3): + # gr.Markdown("## Live Generation Process") + output_image_t2i = gr.Image(label="Generated Image", interactive=False, type="pil") + output_status_t2i = gr.Textbox(label="Generation Status", interactive=False) + + gr.Examples( + examples=[ + ["A sea turtle swimming near a coral reef in the ocean, with a clear blue sky and water in the background.", 15, 3.5, "cosine"], + ["A beautiful sunset over a calm ocean, with a few clouds in the sky.", 15, 3.5, "cosine"] + ], + inputs=[prompt_input_box_t2i, steps_slider_t2i, guidance_scale_slider_t2i, scheduler_radio_t2i], + outputs=[output_image_t2i, output_status_t2i], + fn=generate_viz_wrapper_t2i, + ) + + run_button_ui_t2i.click( + fn=generate_viz_wrapper_t2i, + inputs=[ + prompt_input_box_t2i, + steps_slider_t2i, + guidance_scale_slider_t2i, + scheduler_radio_t2i + ], + outputs=[output_image_t2i, output_status_t2i] + ) + + clear_button_ui_t2i.click( + fn=lambda: (None, ""), + inputs=None, + outputs=[output_image_t2i, output_status_t2i], + queue=False + ) + + think_button_lm.click( + fn=toggle_thinking_mode_lm, + inputs=[thinking_mode_lm], + outputs=[thinking_mode_lm, think_button_lm] + ) + + think_button_mmu.click( + fn=toggle_thinking_mode_mmu, + inputs=[thinking_mode_mmu], + outputs=[thinking_mode_mmu, think_button_mmu] + ) + + + + def initialize_default_model(): + default_model = "MMaDA-8B-Base" + result = handle_model_selection_change(default_model) + return default_model, result + + demo.load( + fn=initialize_default_model, + inputs=None, + outputs=[model_select_radio, model_load_status_box], + queue=True + ) + + def clear_outputs(): + return None, None, None # Clear image, visualization, and final text + + clear_button_ui_lm.click( + fn=clear_outputs, + inputs=None, + outputs=[image_upload_box, output_visualization_box_lm, output_final_text_box_lm], + queue=False + ) + clear_button_ui_mmu.click( + fn=clear_outputs, + inputs=None, + outputs=[image_upload_box, output_visualization_box_mmu, output_final_text_box_mmu], + queue=False + ) + + run_button_ui_lm.click( + fn=generate_viz_wrapper_lm, + inputs=[ + prompt_input_box_lm, + steps_slider_lm, + gen_length_slider_lm, + block_length_slider_lm, + temperature_slider_lm, + cfg_scale_slider_lm, + remasking_dropdown_lm, + thinking_mode_lm + ], + outputs=[output_visualization_box_lm, output_final_text_box_lm] + ) + + run_button_ui_mmu.click( + fn=generate_viz_wrapper, + inputs=[ + image_upload_box, + prompt_input_box_mmu, + steps_slider_mmu, + gen_length_slider_mmu, + block_length_slider_mmu, + temperature_slider_mmu, + cfg_scale_slider_mmu, + remasking_dropdown_mmu, + thinking_mode_mmu + ], + outputs=[output_visualization_box_mmu, output_final_text_box_mmu] + ) + + +if __name__ == "__main__": + print(f"Starting Gradio App. Attempting to use device: {DEVICE}") + demo.launch(share=True) \ No newline at end of file diff --git a/MMaDA/check_lr.py b/MMaDA/check_lr.py new file mode 100644 index 0000000000000000000000000000000000000000..13f3efdf2ee9416eb41db187ac5c94759b59c985 --- /dev/null +++ b/MMaDA/check_lr.py @@ -0,0 +1,27 @@ +import torch +from torch.optim import AdamW + +from models.lr_schedulers import get_scheduler + +MAX_TRAINING_STEPS = 100 +WARMUP_STEPS = 80 +INITIAL_LR = 5e-5 +SCHEDULER_TYPE = "cosine" # "linear", "cosine" +# --------------------------------------------- + +dummy_model = torch.nn.Linear(1, 1) +dummy_optimizer = AdamW(dummy_model.parameters(), lr=INITIAL_LR) + +lr_scheduler = get_scheduler( + name=SCHEDULER_TYPE, + optimizer=dummy_optimizer, + num_warmup_steps=WARMUP_STEPS, + num_training_steps=MAX_TRAINING_STEPS, +) + +all_lrs = [] +for step in range(MAX_TRAINING_STEPS): + all_lrs.append(lr_scheduler.get_last_lr()[0]) + lr_scheduler.step() + +print(all_lrs[79]) \ No newline at end of file diff --git a/MMaDA/check_tokens.py b/MMaDA/check_tokens.py new file mode 100644 index 0000000000000000000000000000000000000000..6d816f11f451ed6ff1c62b1288393d7db33797de --- /dev/null +++ b/MMaDA/check_tokens.py @@ -0,0 +1,191 @@ +#!/usr/bin/env python3 +""" +첓크 방법 +========= +python check_audio_tokens.py \ + --config configs/omada_instruction_tuning.yaml \ + --samples 20 +""" + +import argparse +import random +from pathlib import Path +from typing import Iterable, Optional, Tuple, Union + +import numpy as np +import torch +from omegaconf import OmegaConf +from tqdm import tqdm +from transformers import AutoTokenizer + +from models.modeling_emova_speech_tokenizer import EMOVASpeechTokenizer +from training.data import MixedSpeechTextDataset, VideoSpeechDataset +from training.prompting_utils import UniversalPrompting +from training.utils import image_transform +import sys, os +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +def _to_tensor(entry: Union[torch.Tensor, np.ndarray, list, tuple, str], + vq_model: EMOVASpeechTokenizer) -> torch.Tensor: + """entryź°€ 경딜멓 encode, ģ“ėÆø ķ† ķ°ģ“ė©“ long tensor딜 ė³€ķ™˜.""" + if isinstance(entry, torch.Tensor): + tokens = entry.clone().long() + elif isinstance(entry, np.ndarray): + tokens = torch.from_numpy(entry).long() + elif isinstance(entry, (list, tuple)): + tokens = torch.as_tensor(entry, dtype=torch.long) + elif isinstance(entry, str): + # EMOVA encodeėŠ” (1, L) ė°˜ķ™˜ → 1D딜 ė³€ķ™˜ + tokens = vq_model.encode(entry).squeeze(0).long() + else: + raise TypeError(f"Unsupported token entry type: {type(entry)}") + return tokens.view(-1) + + +def _log_stats(flow: str, path: str, tokens: torch.Tensor, + codebook_size: int = 4096) -> Tuple[int, int]: + max_id = int(tokens.max().item()) + min_id = int(tokens.min().item()) + over = int((tokens >= codebook_size).sum().item()) + under = int((tokens < 0).sum().item()) + + print( + f"[{flow}] path={path} " + f"shape={tuple(tokens.shape)} " + f"min={min_id} max={max_id} " + f"<0={under} >=4096={over}" + ) + return over, under + + +def build_prompting(config) -> UniversalPrompting: + tokenizer = AutoTokenizer.from_pretrained( + config.model.omada.tokenizer_path, + padding_side="left", + ) + special_tokens = ( + "<|soi|>", "<|eoi|>", "<|sov|>", "<|eov|>", "<|t2i|>", + "<|mmu|>", "<|t2v|>", "<|v2v|>", "<|lvg|>", + "<|i2i|>", "<|v2t|>", "<|v2s|>", "<|s2t|>", + "<|t2s|>", "<|s2s|>", "<|soa|>", "<|eoa|>", + ) + prompt = UniversalPrompting( + tokenizer, + max_text_len=config.dataset.preprocessing.max_seq_length, + max_audio_len=config.dataset.preprocessing.max_aud_length, + max_audio_len_short=config.dataset.preprocessing.max_aud_length_short, + ignore_id=-100, + cond_dropout_prob=config.training.cond_dropout_prob, + special_tokens=special_tokens, + use_reserved_token=True, + ) + return prompt + + +def sample_indices(length: int, num: int) -> Tuple[Iterable[int], int]: + """ + Returns iterable of indices and the total count that will be iterated. + If num <= 0 or num >= length, iterates through the whole dataset. + """ + if num is None or num <= 0 or num >= length: + return range(length), length + indices = random.sample(range(length), num) + return indices, len(indices) + + +@torch.no_grad() +def inspect_v2s(config, prompting, vq_model, num_samples: int): + speech_cfg = OmegaConf.to_container( + config.dataset.params.get("video_speech_dataset", {}), + resolve=True + ) or {} + dataset = VideoSpeechDataset( + transform=image_transform, + resolution=config.dataset.preprocessing.resolution, + num_frames=speech_cfg.get("num_frames_speech", 4), + video_root=speech_cfg.get( + "video_root", "/home/work/AIDAS/data/video/openvid1m/video/video" + ), + audio_root=speech_cfg.get( + "audio_root", "/home/work/AIDAS/data/video-speech" + ), + speech_dir_name=speech_cfg.get("speech_dir_name", "openvid-speech-trunc"), + index_path=speech_cfg.get( + "index_path", "/home/work/AIDAS/data/video-speech/openvid-speech.csv" + ), + sample_method=speech_cfg.get("sample_method", "uniform"), + precomputed_tokens_root=speech_cfg.get("precomputed_tokens_root"), + ) + + print(f"\n=== VideoSpeechDataset (v2s) | total={len(dataset)} ===") + total_over = total_under = 0 + indices, total = sample_indices(len(dataset), num_samples) + for idx in tqdm(indices, total=total, desc="v2s audio", unit="sample"): + sample = dataset.data[idx] + speech_path = sample["speech"] + tokens = dataset._load_precomputed_tokens(speech_path) + if tokens is not None: + tokens = tokens.long() + else: + tokens = vq_model.encode(speech_path).squeeze(0).long() + over, under = _log_stats("v2s", speech_path, tokens) + total_over += over + total_under += under + + print(f"[v2s] total >=4096: {total_over} | total <0: {total_under}") + + +@torch.no_grad() +def inspect_t2s(config, prompting, vq_model, num_samples: int): + dataset = MixedSpeechTextDataset(config.dataset.params.audio_data) + + print(f"\n=== MixedSpeechTextDataset (t2s/s2t 공용) | total={len(dataset)} ===") + total_over = total_under = 0 + indices, total = sample_indices(len(dataset), num_samples) + for idx in tqdm(indices, total=total, desc="t2s/s2t audio", unit="sample"): + sample = dataset[idx] + entry = sample["audio_path"] + if isinstance(entry, np.ndarray): + tokens = torch.from_numpy(entry).long() + path_repr = "" + elif isinstance(entry, str): + tokens = vq_model.encode(entry).squeeze(0).long() + path_repr = entry + else: + tokens = torch.as_tensor(entry, dtype=torch.long) + path_repr = "" + over, under = _log_stats("t2s/s2t-source", path_repr, tokens) + total_over += over + total_under += under + + print(f"[t2s] total >=4096: {total_over} | total <0: {total_under}") + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--config", required=True, + help="ķ•™ģŠµģ— ģ‚¬ģš©ķ•œ YAML 설정 ķŒŒģ¼") + parser.add_argument( + "--samples", + type=int, + default=-1, + help="각 ė°ģ“ķ„°ģ…‹ģ—ģ„œ 검사할 ģƒ˜ķ”Œ 수 (<=0ģ“ė©“ 전첓 검사)", + ) + args = parser.parse_args() + + config = OmegaConf.load(args.config) + prompting = build_prompting(config) + + vq_model = EMOVASpeechTokenizer.from_pretrained( + config.model.vq_model_audio.vq_model_name + ) + vq_model.eval() + + inspect_v2s(config, prompting, vq_model, args.samples) + # inspect_t2s(config, prompting, vq_model, args.samples) + + +if __name__ == "__main__": + torch.manual_seed(0) + random.seed(0) + main() diff --git a/MMaDA/configs/mmada_demo.yaml b/MMaDA/configs/mmada_demo.yaml new file mode 100644 index 0000000000000000000000000000000000000000..52b68c6ee0e54d5b487b263cc868449f6144e938 --- /dev/null +++ b/MMaDA/configs/mmada_demo.yaml @@ -0,0 +1,95 @@ +wandb: + entity: null +# run_id: askkz9i2 + resume: 'auto' + +experiment: + project: "demo" + name: "mmada-demo" + output_dir: "mmada-demo" + +model: + vq_model: + type: "magvitv2" + vq_model_name: "showlab/magvitv2" + + mmada: + pretrained_model_path: "Gen-Verse/MMaDA-8B-Base" + w_clip_vit: False + new_vocab_size: 134656 + llm_vocab_size: 126464 + codebook_size: 8192 + num_vq_tokens: 256 + num_new_special_tokens: 0 + tie_word_embeddings: False + + gradient_checkpointing: True + +dataset: + gen_type: "imagenet1k" + und_type: "captioning" + combined_loader_mode: "max_size_cycle" + params: + train_t2i_shards_path_or_url: "/data_storage/shared/datasets/imagenet-1k/data/train" + train_mmu_shards_path_or_url: [ "/data_storage/shared/datasets/SA-1B/sa_{000000..000999}.tar", + "/data_storage/shared/datasets/cc12m/raw/raw/{0000..0999}.tar", + "/data_storage/shared/datasets/laion-aesthetics-12m/{00000..01209}.tar" + ] + train_lm_shards_path_or_url: "/data_storage/shared/datasets/falcon-refinedweb/data/data/*.parquet" + add_caption_prompt: True + external_caption_path: "/data_storage/shared/datasets/SAM-LLaVA-Captions10M" + external_journeydb_caption_path: "/data_storage/shared/datasets/journeydb_anno/train_journeydb_anno.json" + external_laion12m_caption_path: "/data_storage/shared/datasets/laion-aesthetic-12m-captions" + external_cc12m_caption_path: "/data_storage/shared/datasets/cc12m/captions" + validation_prompts_file: "validation_prompts/imagenet_prompts.txt" + shuffle_buffer_size: 1000 + num_workers: 32 + resolution: 256 + pin_memory: True + persistent_workers: True + + preprocessing: + max_seq_length: 512 # for text tokens + resolution: 256 + center_crop: False + random_flip: False + +optimizer: + name: adamw + params: # default adamw params + learning_rate: 5e-5 + scale_lr: False # scale learning rate by total batch size + beta1: 0.9 + beta2: 0.999 + weight_decay: 0.01 + epsilon: 1e-8 + +lr_scheduler: + scheduler: "cosine" + params: + learning_rate: ${optimizer.params.learning_rate} + warmup_steps: 8000 + +training: + gradient_accumulation_steps: 4 + noise_type: "mask" + batch_size_t2i: 5 + batch_size_lm: 1 + batch_size_mmu: 2 + mixed_precision: "bf16" + enable_tf32: True + seed: 10086 + max_train_steps: 500000 + overfit_one_batch: False + cond_dropout_prob: 0.1 + min_masking_rate: 0.0 + label_smoothing: 0.0 + max_grad_norm: 1 + guidance_scale: 1.5 + generation_timesteps: 20 + t2i_coeff: 1.0 + lm_coeff: 0.1 + mmu_coeff: 1.0 + +mask_schedule: + schedule: "cosine" \ No newline at end of file diff --git a/MMaDA/configs/mmada_demo_s2t.yaml b/MMaDA/configs/mmada_demo_s2t.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4b22e9d656620bc1470171375bfc865e2ca55b40 --- /dev/null +++ b/MMaDA/configs/mmada_demo_s2t.yaml @@ -0,0 +1,131 @@ +wandb: + entity: null +# run_id: askkz9i2 + resume: 'auto' + +experiment: + project: "omada-training-stage1" + name: "omada-training-stage1" + output_dir: "ckpts/omada/omada-training-stage1" + max_train_examples_t2i: 40000000 + max_train_examples_mmu: 40000000 + save_every: 5000 + eval_every: 10000000000 + generate_every: 1000000000 + log_every: 1 + log_grad_norm_every: 100 + resume_from_checkpoint: "latest" + +model: + vq_model_image: + type: "magvitv2" + vq_model_name: "showlab/magvitv2" + ### Omada ############################################################### + vq_model_audio: + type: "emova" + vq_model_name: "Emova-ollm/emova_speech_tokenizer_hf" + omada: + tokenizer_path: "GSAI-ML/LLaDA-8B-Instruct" + pretrained_model_path: "Gen-Verse/MMaDA-8B-MixCoT" + # pretrained_model_path: "Gen-Verse/MMaDA-8B-Base" + w_clip_vit: False + new_vocab_size: 138752 + llm_vocab_size: 126464 + codebook_size: 8192 + num_vq_tokens: 256 + num_new_special_tokens: 5 # task token 3 + eoa / soa + tie_word_embeddings: False + ######################################################################### + + gradient_checkpointing: True + +dataset: + gen_type: "pass" + und_type: "pass" + combined_loader_mode: "max_size_cycle" + params: + train_t2i_shards_path_or_url: "/data_storage/shared/datasets/imagenet-1k/data/train" + train_mmu_shards_path_or_url: [ "/data_storage/shared/datasets/SA-1B/sa_{000000..000999}.tar", + "/data_storage/shared/datasets/cc12m/raw/raw/{0000..0999}.tar", + "/data_storage/shared/datasets/laion-aesthetics-12m/{00000..00999}.tar" + ] + train_lm_shards_path_or_url: "/data_storage/shared/datasets/falcon-refinedweb/data/data/*.parquet" + add_caption_prompt: True + external_caption_path: "/data_storage/shared/datasets/SAM-LLaVA-Captions10M" + external_journeydb_caption_path: "/data_storage/shared/datasets/journeydb_anno/train_journeydb_anno.json" + external_laion12m_caption_path: "/data_storage/shared/datasets/laion-aesthetic-12m-captions" + external_cc12m_caption_path: "/data_storage/shared/datasets/cc12m/captions" + validation_prompts_file: "validation_prompts/imagenet_prompts.txt" + mmu_image_root: "/data_storage/ty/MMaDA/mmu_validation" + ### Omada ############################################################### + video_root: "/home/work/AIDAS/data/video/panda70m/panda70m_training_2m" + # subset for gigaspeech: xs, xl + # subset for librispeech: train-clean-360, train-clean-100 + # subset for commonvoice: validated, invalidated + audio_data: + - name: "gigaspeech" + subset: "xl" + split: "train" + - name: "librispeech" + subset: "train-clean-360" + - name: "commonvoice" + subset: "validated" + ######################################################################### + shuffle_buffer_size: 1000 + num_workers: 8 + resolution: 256 + pin_memory: True + persistent_workers: True + + preprocessing: + max_seq_length: 128 # for text tokens + max_aud_length: 256 # for audio tokens + resolution: 128 + center_crop: False + random_flip: False + +optimizer: + name: adamw + params: # default adamw params + learning_rate: 1e-5 + scale_lr: False # scale learning rate by total batch size + beta1: 0.9 + beta2: 0.999 + weight_decay: 0.01 + epsilon: 1e-8 + +lr_scheduler: + scheduler: "cosine" + params: + learning_rate: ${optimizer.params.learning_rate} + warmup_steps: 3000 + min_lr_scale: 0.1 + +training: + gradient_accumulation_steps: 1 + noise_type: "mask" + batch_size_t2i: 0 + batch_size_lm: 0 + batch_size_mmu: 0 + batch_size_v2t: 2 + batch_size_s2t: 2 + batch_size_t2s: 3 + + mixed_precision: "bf16" + enable_tf32: True + seed: 10086 + max_train_steps: 200000 + max_train_epochs: 1 + overfit_one_batch: False + cond_dropout_prob: 0.1 + min_masking_rate: 0.0 + label_smoothing: 0.0 + max_grad_norm: 1 + guidance_scale: 0.75 + generation_timesteps: 16 + # t2i_coeff: 0.1 + # lm_coeff: 0.1 + # mmu_coeff: 0.1 + v2t_coeff: 1.0 + t2s_coeff: 1.0 + s2t_coeff: 1.0 \ No newline at end of file diff --git a/MMaDA/configs/mmada_demo_speech.yaml b/MMaDA/configs/mmada_demo_speech.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5f728527cdc8ba60d9509f449ad90090ce21e620 --- /dev/null +++ b/MMaDA/configs/mmada_demo_speech.yaml @@ -0,0 +1,101 @@ +wandb: + entity: null +# run_id: askkz9i2 + resume: 'auto' + +experiment: + project: "demo" + name: "mmada-demo" + output_dir: "mmada-demo" + +model: + vq_model: + type: "magvitv2" + vq_model_name: "showlab/magvitv2" + speech_model: + type: "emova" + speech_model_name: "Emova-ollm/emova_speech_tokenizer_hf" + + mmada: + pretrained_model_path: "Gen-Verse/MMaDA-8B-MixCoT" + w_clip_vit: False + new_vocab_size: 138752 + llm_vocab_size: 126464 + codebook_size: 8192 + speech_codebook_size: 4096 + num_vq_tokens: 256 + num_speech_vq_tokens: 100 + num_new_special_tokens: 3 + tie_word_embeddings: False + train_step: 25000 + + gradient_checkpointing: True + +dataset: + gen_type: "imagenet1k" + und_type: "captioning" + combined_loader_mode: "max_size_cycle" + params: + train_t2i_shards_path_or_url: "/data_storage/shared/datasets/imagenet-1k/data/train" + train_mmu_shards_path_or_url: [ "/data_storage/shared/datasets/SA-1B/sa_{000000..000999}.tar", + "/data_storage/shared/datasets/cc12m/raw/raw/{0000..0999}.tar", + "/data_storage/shared/datasets/laion-aesthetics-12m/{00000..01209}.tar" + ] + train_lm_shards_path_or_url: "/data_storage/shared/datasets/falcon-refinedweb/data/data/*.parquet" + add_caption_prompt: True + external_caption_path: "/data_storage/shared/datasets/SAM-LLaVA-Captions10M" + external_journeydb_caption_path: "/data_storage/shared/datasets/journeydb_anno/train_journeydb_anno.json" + external_laion12m_caption_path: "/data_storage/shared/datasets/laion-aesthetic-12m-captions" + external_cc12m_caption_path: "/data_storage/shared/datasets/cc12m/captions" + validation_prompts_file: "validation_prompts/imagenet_prompts.txt" + shuffle_buffer_size: 1000 + num_workers: 32 + resolution: 256 + pin_memory: True + persistent_workers: True + + preprocessing: + max_seq_length: 512 # for text tokens + resolution: 256 + center_crop: False + random_flip: False + +optimizer: + name: adamw + params: # default adamw params + learning_rate: 5e-5 + scale_lr: False # scale learning rate by total batch size + beta1: 0.9 + beta2: 0.999 + weight_decay: 0.01 + epsilon: 1e-8 + +lr_scheduler: + scheduler: "cosine" + params: + learning_rate: ${optimizer.params.learning_rate} + warmup_steps: 8000 + +training: + gradient_accumulation_steps: 4 + noise_type: "mask" + batch_size_t2i: 5 + batch_size_lm: 1 + batch_size_mmu: 2 + mixed_precision: "bf16" + enable_tf32: True + seed: 10086 + max_train_steps: 500000 + overfit_one_batch: False + cond_dropout_prob: 0.1 + min_masking_rate: 0.0 + label_smoothing: 0.0 + max_grad_norm: 1 + guidance_scale: 1.5 + generation_timesteps: 20 + t2i_coeff: 1.0 + lm_coeff: 0.1 + mmu_coeff: 1.0 + +mask_schedule: + schedule: "cosine" \ No newline at end of file diff --git a/MMaDA/configs/mmada_demo_video.yaml b/MMaDA/configs/mmada_demo_video.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5ee83ead01d3a99db92fc88a1dabae3b3a624c86 --- /dev/null +++ b/MMaDA/configs/mmada_demo_video.yaml @@ -0,0 +1,95 @@ +wandb: + entity: null +# run_id: askkz9i2 + resume: 'auto' + +experiment: + project: "demo" + name: "mmada-demo" + output_dir: "mmada-demo" + +model: + vq_model: + type: "magvitv2" + vq_model_name: "showlab/magvitv2" + + mmada: + pretrained_model_path: "Gen-Verse/MMaDA-8B-Base" + w_clip_vit: False + new_vocab_size: 134656 + llm_vocab_size: 126464 + codebook_size: 8192 + num_vq_tokens: 256 + num_new_special_tokens: 0 + tie_word_embeddings: False + + gradient_checkpointing: True + +dataset: + gen_type: "imagenet1k" + und_type: "captioning" + combined_loader_mode: "max_size_cycle" + params: + train_t2i_shards_path_or_url: "/data_storage/shared/datasets/imagenet-1k/data/train" + train_mmu_shards_path_or_url: [ "/data_storage/shared/datasets/SA-1B/sa_{000000..000999}.tar", + "/data_storage/shared/datasets/cc12m/raw/raw/{0000..0999}.tar", + "/data_storage/shared/datasets/laion-aesthetics-12m/{00000..01209}.tar" + ] + train_lm_shards_path_or_url: "/data_storage/shared/datasets/falcon-refinedweb/data/data/*.parquet" + add_caption_prompt: True + external_caption_path: "/data_storage/shared/datasets/SAM-LLaVA-Captions10M" + external_journeydb_caption_path: "/data_storage/shared/datasets/journeydb_anno/train_journeydb_anno.json" + external_laion12m_caption_path: "/data_storage/shared/datasets/laion-aesthetic-12m-captions" + external_cc12m_caption_path: "/data_storage/shared/datasets/cc12m/captions" + validation_prompts_file: "validation_prompts/imagenet_prompts.txt" + shuffle_buffer_size: 1000 + num_workers: 32 + resolution: 128 + pin_memory: True + persistent_workers: True + + preprocessing: + max_seq_length: 512 # for text tokens + resolution: 256 + center_crop: False + random_flip: False + +optimizer: + name: adamw + params: # default adamw params + learning_rate: 5e-5 + scale_lr: False # scale learning rate by total batch size + beta1: 0.9 + beta2: 0.999 + weight_decay: 0.01 + epsilon: 1e-8 + +lr_scheduler: + scheduler: "cosine" + params: + learning_rate: ${optimizer.params.learning_rate} + warmup_steps: 8000 + +training: + gradient_accumulation_steps: 4 + noise_type: "mask" + batch_size_t2i: 5 + batch_size_lm: 1 + batch_size_mmu: 2 + mixed_precision: "bf16" + enable_tf32: True + seed: 10086 + max_train_steps: 500000 + overfit_one_batch: False + cond_dropout_prob: 0.1 + min_masking_rate: 0.0 + label_smoothing: 0.0 + max_grad_norm: 1 + guidance_scale: 1.5 + generation_timesteps: 20 + t2i_coeff: 1.0 + lm_coeff: 0.1 + mmu_coeff: 1.0 + +mask_schedule: + schedule: "cosine" \ No newline at end of file diff --git a/MMaDA/configs/mmada_demo_video_temp.yaml b/MMaDA/configs/mmada_demo_video_temp.yaml new file mode 100644 index 0000000000000000000000000000000000000000..93dfd4ce594188b096b18c77c43aef41592f952b --- /dev/null +++ b/MMaDA/configs/mmada_demo_video_temp.yaml @@ -0,0 +1,95 @@ +wandb: + entity: null +# run_id: askkz9i2 + resume: 'auto' + +experiment: + project: "demo" + name: "mmada-demo" + output_dir: "mmada-demo" + +model: + vq_model: + type: "magvitv2" + vq_model_name: "showlab/magvitv2" + + mmada: + pretrained_model_path: "Gen-Verse/MMaDA-8B-Base" + w_clip_vit: False + new_vocab_size: 134656 + llm_vocab_size: 126464 + codebook_size: 8192 + num_vq_tokens: 900 + num_new_special_tokens: 0 + tie_word_embeddings: False + + gradient_checkpointing: True + +dataset: + gen_type: "imagenet1k" + und_type: "captioning" + combined_loader_mode: "max_size_cycle" + params: + train_t2i_shards_path_or_url: "/data_storage/shared/datasets/imagenet-1k/data/train" + train_mmu_shards_path_or_url: [ "/data_storage/shared/datasets/SA-1B/sa_{000000..000999}.tar", + "/data_storage/shared/datasets/cc12m/raw/raw/{0000..0999}.tar", + "/data_storage/shared/datasets/laion-aesthetics-12m/{00000..01209}.tar" + ] + train_lm_shards_path_or_url: "/data_storage/shared/datasets/falcon-refinedweb/data/data/*.parquet" + add_caption_prompt: True + external_caption_path: "/data_storage/shared/datasets/SAM-LLaVA-Captions10M" + external_journeydb_caption_path: "/data_storage/shared/datasets/journeydb_anno/train_journeydb_anno.json" + external_laion12m_caption_path: "/data_storage/shared/datasets/laion-aesthetic-12m-captions" + external_cc12m_caption_path: "/data_storage/shared/datasets/cc12m/captions" + validation_prompts_file: "validation_prompts/imagenet_prompts.txt" + shuffle_buffer_size: 1000 + num_workers: 32 + resolution: 480 + pin_memory: True + persistent_workers: True + + preprocessing: + max_seq_length: 512 # for text tokens + resolution: 480 + center_crop: False + random_flip: False + +optimizer: + name: adamw + params: # default adamw params + learning_rate: 5e-5 + scale_lr: False # scale learning rate by total batch size + beta1: 0.9 + beta2: 0.999 + weight_decay: 0.01 + epsilon: 1e-8 + +lr_scheduler: + scheduler: "cosine" + params: + learning_rate: ${optimizer.params.learning_rate} + warmup_steps: 8000 + +training: + gradient_accumulation_steps: 4 + noise_type: "mask" + batch_size_t2i: 5 + batch_size_lm: 1 + batch_size_mmu: 2 + mixed_precision: "bf16" + enable_tf32: True + seed: 10086 + max_train_steps: 500000 + overfit_one_batch: False + cond_dropout_prob: 0.1 + min_masking_rate: 0.0 + label_smoothing: 0.0 + max_grad_norm: 1 + guidance_scale: 1.5 + generation_timesteps: 20 + t2i_coeff: 1.0 + lm_coeff: 0.1 + mmu_coeff: 1.0 + +mask_schedule: + schedule: "cosine" \ No newline at end of file diff --git a/MMaDA/configs/mmada_pretraining_i2i.yaml b/MMaDA/configs/mmada_pretraining_i2i.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4304267a34079b3e58efaf555801656928c523e9 --- /dev/null +++ b/MMaDA/configs/mmada_pretraining_i2i.yaml @@ -0,0 +1,86 @@ +wandb: + entity: null +# run_id: askkz9i2 + resume: 'auto' + +experiment: + project: "ommda-training-i2i_256_0715" + name: "ommda-training-i2i-mmada-instruct_256_0715" + output_dir: "ommda-training-i2i-mmada-instruct_256_0715" + save_every: 5000 + eval_every: 20000 + generate_every: 5000 + num_validation_images: 20 + log_every: 1 + log_grad_norm_every: 100 + resume_from_checkpoint: "latest" + val_every: 50000 + max_val_examples_t2i: 2000 + +model: + vq_model: + type: "magvitv2" + vq_model_name: "showlab/magvitv2" + + mmada: + tokenizer_path: "GSAI-ML/LLaDA-8B-Instruct" + pretrained_model_path: "Gen-Verse/MMaDA-8B-Base" + w_clip_vit: False + new_vocab_size: 134656 + llm_vocab_size: 126464 + codebook_size: 8192 + num_vq_tokens: 256 + num_new_special_tokens: 0 + tie_word_embeddings: False + + gradient_checkpointing: True + +dataset: + params: + num_workers: 0 + resolution: 256 + pin_memory: True + persistent_workers: True + + preprocessing: + max_seq_length: 256 # for text tokens + resolution: 256 + center_crop: False + random_flip: False + +optimizer: + name: adamw + params: # default adamw params + learning_rate: 5e-5 + scale_lr: False # scale learning rate by total batch size + beta1: 0.9 + beta2: 0.999 + weight_decay: 0.01 + epsilon: 1e-8 + +lr_scheduler: + scheduler: "cosine" + params: + learning_rate: ${optimizer.params.learning_rate} + warmup_steps: 5000 + min_lr_scale: 0.1 + +training: + gradient_accumulation_steps: 4 + noise_type: "mask" + batch_size_i2i: 1 + mixed_precision: "bf16" + enable_tf32: True + seed: 10086 + max_train_steps: 50000 + overfit_one_batch: False + cond_dropout_prob: 0.1 + min_masking_rate: 0.0 + label_smoothing: 0.0 + max_grad_norm: 1 + guidance_scale: 5 + generation_timesteps: 50 + t2i_coeff: 1.0 + lm_coeff: 0.1 + mmu_coeff: 0.5 + validation_seed: 42 \ No newline at end of file diff --git a/MMaDA/configs/mmada_pretraining_s2t.yaml b/MMaDA/configs/mmada_pretraining_s2t.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6f8b850f729a17d21e5d03f7eda8605d5245db52 --- /dev/null +++ b/MMaDA/configs/mmada_pretraining_s2t.yaml @@ -0,0 +1,96 @@ +wandb: + entity: null +# run_id: askkz9i2 + resume: 'auto' + +experiment: + project: "ommda-training-s2t" + name: "ommda-training-s2t-mmada" + output_dir: "ommda-training-s2t-mmada" + save_every: 5000 + eval_every: 20000 + generate_every: 5000 + num_validation_images: 20 + log_every: 1 + log_grad_norm_every: 100 + resume_from_checkpoint: False + val_every: 50000 + max_val_examples_t2i: 2000 + +model: + vq_model: + type: "emova" + vq_model_name: "Emova-ollm/emova_speech_tokenizer_hf" + + mmada: + tokenizer_path: "GSAI-ML/LLaDA-8B-Instruct" + pretrained_model_path: "Gen-Verse/MMaDA-8B-Base" + w_clip_vit: False + new_vocab_size: 138752 + llm_vocab_size: 126464 + codebook_size: 8192 + speech_codebook_size: 4096 + # num_vq_tokens: 256 + # num_speech_vq_tokens: 250 + num_new_special_tokens: 3 + tie_word_embeddings: False + + gradient_checkpointing: True + +dataset: + params: + num_workers: 0 + resolution: 256 + pin_memory: True + persistent_workers: True + + preprocessing: + max_seq_length: 256 # for text tokens + resolution: 256 + center_crop: False + random_flip: False + + data: + # subset for gigaspeech: xs, xl + # subset for librispeech: train-clean-360, train-clean-100 + # subset for commonvoice: validated, invalidated + name: "gigaspeech" + subset: "xl" + split: "train" + +optimizer: + name: adamw + params: # default adamw params + learning_rate: 5e-5 + scale_lr: False # scale learning rate by total batch size + beta1: 0.9 + beta2: 0.999 + weight_decay: 0.01 + epsilon: 1e-8 + +lr_scheduler: + scheduler: "cosine" + params: + learning_rate: ${optimizer.params.learning_rate} + warmup_steps: 5000 + min_lr_scale: 0.1 + +training: + gradient_accumulation_steps: 4 + noise_type: "mask" + batch_size_s2t: 4 + mixed_precision: "bf16" + enable_tf32: True + seed: 10086 + max_train_steps: 50000 + overfit_one_batch: False + cond_dropout_prob: 0.1 + min_masking_rate: 0.0 + label_smoothing: 0.0 + max_grad_norm: 1 + guidance_scale: 5 + generation_timesteps: 50 + t2i_coeff: 1.0 + lm_coeff: 0.1 + mmu_coeff: 0.5 + validation_seed: 42 \ No newline at end of file diff --git a/MMaDA/configs/mmada_pretraining_stage1_llada_instruct.yaml b/MMaDA/configs/mmada_pretraining_stage1_llada_instruct.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a58d73f01fc6dcadf9e88b7c66fe79eb5785f186 --- /dev/null +++ b/MMaDA/configs/mmada_pretraining_stage1_llada_instruct.yaml @@ -0,0 +1,100 @@ +wandb: + entity: null +# run_id: askkz9i2 + resume: 'auto' + +experiment: + project: "mmada-training-stage1" + name: "mmada-training-stage1-llada-instruct" + output_dir: "mmada-training-stage1-llada-instruct" + max_train_examples_t2i: 40000000 + max_train_examples_mmu: 40000000 + save_every: 10000 + eval_every: 2500 + generate_every: 1000 + log_every: 50 + log_grad_norm_every: 100 + resume_from_checkpoint: "latest" + +model: + vq_model: + type: "magvitv2" + vq_model_name: "showlab/magvitv2" + mmada: + pretrained_model_path: "GSAI-ML/LLaDA-8B-Instruct" + w_clip_vit: False + new_vocab_size: 134656 + llm_vocab_size: 126464 + codebook_size: 8192 + num_vq_tokens: 256 + num_new_special_tokens: 0 + tie_word_embeddings: False + + gradient_checkpointing: True + +dataset: + gen_type: "imagenet1k" + und_type: "captioning" + combined_loader_mode: "max_size_cycle" + params: + train_t2i_shards_path_or_url: "/data_storage/shared/datasets/imagenet-1k/data/train" + train_mmu_shards_path_or_url: [ "/data_storage/shared/datasets/SA-1B/sa_{000000..000999}.tar", + "/data_storage/shared/datasets/cc12m/raw/raw/{0000..0999}.tar", + "/data_storage/shared/datasets/laion-aesthetics-12m/{00000..00999}.tar" + ] + train_lm_shards_path_or_url: "/data_storage/shared/datasets/falcon-refinedweb/data/data/*.parquet" + add_caption_prompt: True + external_caption_path: "/data_storage/shared/datasets/SAM-LLaVA-Captions10M" + external_journeydb_caption_path: "/data_storage/shared/datasets/journeydb_anno/train_journeydb_anno.json" + external_laion12m_caption_path: "/data_storage/shared/datasets/laion-aesthetic-12m-captions" + external_cc12m_caption_path: "/data_storage/shared/datasets/cc12m/captions" + validation_prompts_file: "validation_prompts/imagenet_prompts.txt" + mmu_image_root: "/data_storage/ty/MMaDA/mmu_validation" + shuffle_buffer_size: 1000 + num_workers: 32 + resolution: 256 + pin_memory: True + persistent_workers: True + + preprocessing: + max_seq_length: 128 # for text tokens + resolution: 256 + center_crop: False + random_flip: False + +optimizer: + name: adamw + params: # default adamw params + learning_rate: 1e-4 + scale_lr: False # scale learning rate by total batch size + beta1: 0.9 + beta2: 0.999 + weight_decay: 0.01 + epsilon: 1e-8 + +lr_scheduler: + scheduler: "cosine" + params: + learning_rate: ${optimizer.params.learning_rate} + warmup_steps: 5000 + +training: + gradient_accumulation_steps: 2 + noise_type: "mask" + batch_size_t2i: 7 + batch_size_lm: 2 + batch_size_mmu: 6 + mixed_precision: "bf16" + enable_tf32: True + seed: 10086 + max_train_steps: 500000 + overfit_one_batch: False + cond_dropout_prob: 0.1 + min_masking_rate: 0.0 + label_smoothing: 0.0 + max_grad_norm: 1 + guidance_scale: 1.5 + generation_timesteps: 12 + t2i_coeff: 1.0 + lm_coeff: 0.1 + mmu_coeff: 1.0 \ No newline at end of file diff --git a/MMaDA/configs/mmada_pretraining_stage2_llada_instruct.yaml b/MMaDA/configs/mmada_pretraining_stage2_llada_instruct.yaml new file mode 100644 index 0000000000000000000000000000000000000000..799f379987b27fdd3509e5747238a641fc18aae2 --- /dev/null +++ b/MMaDA/configs/mmada_pretraining_stage2_llada_instruct.yaml @@ -0,0 +1,109 @@ +wandb: + entity: null +# run_id: askkz9i2 + resume: 'auto' + +experiment: + project: "mmada-training-stage2" + name: "mmada-training-stage2-llada-instruct" + output_dir: "mmada-training-stage2-llada-instruct" + max_train_examples_t2i: 40000000 + max_train_examples_mmu: 40000000 + save_every: 10000 + eval_every: 2500 + generate_every: 1000 + log_every: 50 + log_grad_norm_every: 100 + resume_from_checkpoint: "latest" + val_every: 50 + max_val_examples_t2i: 2000 + +model: + vq_model: + type: "magvitv2" + vq_model_name: "showlab/magvitv2" + + mmada: + tokenizer_path: "GSAI-ML/LLaDA-8B-Instruct" + pretrained_model_path: "path/to/your/checkpoint" + w_clip_vit: False + new_vocab_size: 134656 + llm_vocab_size: 126464 + codebook_size: 8192 + num_vq_tokens: 256 + num_new_special_tokens: 0 + tie_word_embeddings: False + + gradient_checkpointing: True + +dataset: + gen_type: "t2i" + und_type: "captioning" + combined_loader_mode: "max_size_cycle" + params: + train_t2i_shards_path_or_url: [ "/data_storage/shared/datasets/SA-1B/sa_{000000..000999}.tar", + "/data_storage/shared/datasets/cc12m/raw/raw/{0000..0999}.tar", + "/data_storage/shared/datasets/laion-aesthetics-12m/{00000..00999}.tar" + ] + train_mmu_shards_path_or_url: [ "/data_storage/shared/datasets/SA-1B/sa_{000000..000999}.tar", + "/data_storage/shared/datasets/cc12m/raw/raw/{0000..0999}.tar", + "/data_storage/shared/datasets/laion-aesthetics-12m/{00000..00999}.tar" + ] + train_lm_shards_path_or_url: "/data_storage/shared/datasets/falcon-refinedweb/data/data/*.parquet" + add_caption_prompt: True + external_caption_path: "/data_storage/shared/datasets/SAM-LLaVA-Captions10M" + external_journeydb_caption_path: "/data_storage/shared/datasets/journeydb_anno/train_journeydb_anno.json" + external_laion12m_caption_path: "/data_storage/ty/datasets/laion-aesthetics-12m-images-2" + external_cc12m_caption_path: "/data_storage/shared/datasets/cc12m/new_captions" + validation_prompts_file: "validation_prompts/text2image_prompts.txt" + mmu_image_root: "/data_storage/ty/MMaDA/mmu_validation" + shuffle_buffer_size: 1000 + num_workers: 32 + resolution: 256 + pin_memory: True + persistent_workers: True + + preprocessing: + max_seq_length: 256 # for text tokens + resolution: 256 + center_crop: False + random_flip: False + +optimizer: + name: adamw + params: # default adamw params + learning_rate: 5e-5 + scale_lr: False # scale learning rate by total batch size + beta1: 0.9 + beta2: 0.999 + weight_decay: 0.01 + epsilon: 1e-8 + +lr_scheduler: + scheduler: "cosine" + params: + learning_rate: ${optimizer.params.learning_rate} + warmup_steps: 5000 + min_lr_scale: 0.1 + +training: + gradient_accumulation_steps: 2 + noise_type: "mask" + batch_size_t2i: 7 + batch_size_lm: 2 + batch_size_mmu: 3 + mixed_precision: "bf16" + enable_tf32: True + seed: 10086 + max_train_steps: 1000000 + overfit_one_batch: False + cond_dropout_prob: 0.1 + min_masking_rate: 0.0 + label_smoothing: 0.0 + max_grad_norm: 1 + guidance_scale: 3 + generation_timesteps: 12 + t2i_coeff: 1.0 + lm_coeff: 0.1 + mmu_coeff: 0.5 + validation_seed: 42 \ No newline at end of file diff --git a/MMaDA/configs/mmada_pretraining_stage3_llada_instruct.yaml b/MMaDA/configs/mmada_pretraining_stage3_llada_instruct.yaml new file mode 100644 index 0000000000000000000000000000000000000000..970226c8d402a4e06981ea99046ffda36f268698 --- /dev/null +++ b/MMaDA/configs/mmada_pretraining_stage3_llada_instruct.yaml @@ -0,0 +1,112 @@ +wandb: + entity: null +# run_id: askkz9i2 + resume: 'auto' + +experiment: + project: "mmada-training-stage3" + name: "mmada-training-stage3-llada-instruct" + output_dir: "mmada-training-stage3-llada-instruct" + max_train_examples_t2i: 40000000 # + max_train_examples_mmu: 40000000 # + save_every: 10000 + eval_every: 2500 + generate_every: 1000 + log_every: 50 + log_grad_norm_every: 100 + resume_from_checkpoint: "latest" + val_every: 50 + max_val_examples_t2i: 2000 + +model: + vq_model: + type: "magvitv2" + vq_model_name: "showlab/magvitv2" + + mmada: + tokenizer_path: "GSAI-ML/LLaDA-8B-Instruct" + pretrained_model_path: "path/to/your/checkpoint" + w_clip_vit: False + new_vocab_size: 134656 + llm_vocab_size: 126464 + codebook_size: 8192 + num_vq_tokens: 256 + num_new_special_tokens: 0 + tie_word_embeddings: False + + gradient_checkpointing: True + +dataset: + gen_type: "t2i" + und_type: "captioning" + combined_loader_mode: "max_size_cycle" + params: + train_t2i_shards_path_or_url: [ # + "/data_storage/shared/datasets/JourneyDB/train/imgs/data/train/imgs/{000..199}.tgz", + "/data_storage/shared/datasets/laion-aesthetics-12m/{00000..00999}.tar", + "/data_storage/shared/datasets/text-to-image-2M/data_512_2M" + ] + train_mmu_shards_path_or_url: [ "/data_storage/shared/datasets/SA-1B/sa_{000000..000999}.tar", # + "/data_storage/shared/datasets/cc12m/raw/raw/{0000..0999}.tar", + "/data_storage/shared/datasets/laion-aesthetics-12m/{00000..00999}.tar" + ] + train_lm_shards_path_or_url: "/data_storage/ty/shared/datasets/3-instruct-datasets/parquet/*.parquet" + add_caption_prompt: True + external_caption_path: "/data_storage/shared/datasets/SAM-LLaVA-Captions10M" + external_journeydb_caption_path: "/data_storage/shared/datasets/journeydb_anno/train_journeydb_anno.json" + external_laion12m_caption_path: "/data_storage/ty/datasets/laion-aesthetics-12m-images-2" + external_cc12m_caption_path: "/data_storage/shared/datasets/cc12m/new_captions" + external_text_to_image_2M_512_caption_path: "/data_storage/shared/datasets/text-to-image-2M/data_512_2M_captions" + validation_prompts_file: "validation_prompts/text2image_prompts.txt" + mmu_image_root: "/data_storage/ty/MMaDA/mmu_validation" + lm_chat_validation_jsonl: "/data_storage/ty/MMaDA/lm_chat_validation/questions.jsonl" + shuffle_buffer_size: 1000 + num_workers: 32 + resolution: 512 + pin_memory: True + persistent_workers: True + + preprocessing: + max_seq_length: 512 # for text tokens 512 + resolution: 512 + center_crop: False + random_flip: False + +optimizer: + name: adamw + params: # default adamw params + learning_rate: 5e-5 + scale_lr: False # scale learning rate by total batch size + beta1: 0.9 + beta2: 0.999 + weight_decay: 0.01 + epsilon: 1e-8 + +lr_scheduler: + scheduler: "cosine" + params: + learning_rate: ${optimizer.params.learning_rate} + warmup_steps: 5000 + min_lr_scale: 0.1 + +training: + gradient_accumulation_steps: 4 # 4 + noise_type: "mask" + batch_size_t2i: 4 # 3~4 + batch_size_lm: 1 + batch_size_mmu: 1 + mixed_precision: "bf16" + enable_tf32: True + seed: 10086 + max_train_steps: 1000000 + overfit_one_batch: False + cond_dropout_prob: 0.1 + min_masking_rate: 0.0 + label_smoothing: 0.0 + max_grad_norm: 1 + guidance_scale: 3 + generation_timesteps: 12 + t2i_coeff: 1.0 + lm_coeff: 0.4 # ~0.5 + mmu_coeff: 0.5 + validation_seed: 42 \ No newline at end of file diff --git a/MMaDA/configs/mmada_pretraining_stage3_llada_instruct_512_cot.yaml b/MMaDA/configs/mmada_pretraining_stage3_llada_instruct_512_cot.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ab923a7a2965dd0ff18d4601840b16cdcfa989cc --- /dev/null +++ b/MMaDA/configs/mmada_pretraining_stage3_llada_instruct_512_cot.yaml @@ -0,0 +1,123 @@ +wandb: + entity: null +# run_id: askkz9i2 + resume: 'auto' + +experiment: + project: "mmada-training-stage3" + name: "mmada-training-stage3-llada-instruct-512-cot-uni" + output_dir: "mmada-training-stage3-llada-instruct-512-cot-uni" + max_train_examples_t2i: 40000000 # + max_train_examples_mmu: 40000000 # + save_every: 10000 + eval_every: 2500 + generate_every: 1000 + log_every: 50 + log_grad_norm_every: 100 + # resume_from_checkpoint: False + resume_from_checkpoint: "latest" + val_every: 50 + max_val_examples_t2i: 2000 + +model: + vq_model: + type: "magvitv2" + vq_model_name: "showlab/magvitv2" + + mmada: + tokenizer_path: "GSAI-ML/LLaDA-8B-Instruct" + pretrained_model_path: "path/to/your/checkpoint" + w_clip_vit: False + new_vocab_size: 134656 + llm_vocab_size: 126464 + codebook_size: 8192 + num_vq_tokens: 1024 + num_new_special_tokens: 0 + tie_word_embeddings: False + + gradient_checkpointing: True + +dataset: + gen_type: "t2i" + und_type: "captioning" + combined_loader_mode: "max_size_cycle" + params: + train_t2i_shards_path_or_url: [ "/data_storage/shared/datasets/JourneyDB/train/imgs/data/train/imgs/{000..199}.tgz", + "/data_storage/shared/datasets/laion-aesthetics-12m-filter/{00000..00999}.tar", + # "/data_storage/shared/datasets/text-to-image-2M/data_512_2M/data_{000000..000046}.tar" + ] + train_mmu_shards_path_or_url: [ "/data_storage/shared/datasets/multimodal_cot/ai2d/new_images.tar", + "/data_storage/shared/datasets/multimodal_cot/clevr/images.tar", + "/data_storage/shared/datasets/multimodal_cot/docvqa/images.tar", + "/data_storage/shared/datasets/multimodal_cot/geo/images.tar", + "/data_storage/shared/datasets/laion-aesthetics-12m/{00000..00999}.tar", + ] + train_lm_shards_path_or_url: "/data_storage/shared/datasets/3-cot-sft/parquet/*.parquet" + add_caption_prompt: True + external_caption_path: "/data_storage/shared/datasets/SAM-LLaVA-Captions10M" + external_journeydb_caption_path: "/data_storage/shared/datasets/journeydb_anno/train_journeydb_anno.json" + external_laion12m_caption_path: "/data_storage/ty/datasets/laion-aesthetics-12m-images-2" + external_cc12m_caption_path: "/data_storage/shared/datasets/cc12m/new_captions" + external_text_to_image_2M_512_caption_path: "/data_storage/shared/datasets/text-to-image-2M/data_512_2M_captions" + external_ai2d_caption_path: "/data_storage/shared/datasets/multimodal_cot/ai2d/new_metadata.csv" + external_clevr_caption_path: "/data_storage/shared/datasets/multimodal_cot/clevr/metadata.csv" + external_docvqa_caption_path: "/data_storage/shared/datasets/multimodal_cot/docvqa/metadata.csv" + external_geo_caption_path: "/data_storage/shared/datasets/multimodal_cot/geo/metadata.csv" + validation_prompts_file: "validation_prompts/text2image_prompts.txt" + mmu_image_root: "/data_storage/ty/MMaDA/mmu_validation" + mmu_validation_prompts_file: "/data_storage/ty/MMaDA/mmu_validation/prompts.jsonl" + lm_chat_validation_jsonl: "/data_storage/ty/MMaDA/lm_chat_validation/questions.jsonl" + shuffle_buffer_size: 1000 + num_workers: 32 + resolution: 512 + pin_memory: True + persistent_workers: True + + preprocessing: + max_seq_length: 512 # for text tokens in t2i & mmu + max_lm_text_length: 1536 # for text tokens in lm/lm_chat + resolution: 512 + center_crop: False + random_flip: False + +optimizer: + name: adamw + params: # default adamw params + learning_rate: 5e-5 + scale_lr: False # scale learning rate by total batch size + beta1: 0.9 + beta2: 0.999 + weight_decay: 0.01 + epsilon: 1e-8 + +lr_scheduler: + scheduler: "cosine" + params: + learning_rate: ${optimizer.params.learning_rate} + warmup_steps: 5000 + min_lr_scale: 0.1 + +training: + gradient_accumulation_steps: 4 # 4 + noise_type: "mask" + batch_size_t2i: 1 + batch_size_lm: 2 + batch_size_mmu: 1 + mixed_precision: "bf16" + enable_tf32: True + seed: 10086 + max_train_steps: 1000000 + overfit_one_batch: False + cond_dropout_prob: 0.1 + min_masking_rate: 0.0 + label_smoothing: 0.0 + max_grad_norm: 1 + guidance_scale: 5 + generation_timesteps: 20 + t2i_coeff: 1.0 + lm_coeff: 0.5 + mmu_coeff: 0.5 + +validation: + quantative_prompts_file: "/data_storage/ty/MMaDA/validation_prompts/quantative.txt" + quantative_batch_size: 8 \ No newline at end of file diff --git a/MMaDA/configs/mmada_pretraining_stage4_llada_instruct.yaml b/MMaDA/configs/mmada_pretraining_stage4_llada_instruct.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4ad5d41453cc5dbc60256107be607f4554472816 --- /dev/null +++ b/MMaDA/configs/mmada_pretraining_stage4_llada_instruct.yaml @@ -0,0 +1,134 @@ +wandb: + entity: null +# run_id: askkz9i2 + resume: 'auto' + +experiment: + project: "mmada-training-stage4" + name: "mmada-training-stage4-llada-instruct" + output_dir: "mmada-training-stage4-llada-instruct" + max_train_examples_t2i: 40000000 # + max_train_examples_mmu: 40000000 # + save_every: 10000 + eval_every: 2500 + generate_every: 1000 + log_every: 50 + log_grad_norm_every: 100 + resume_from_checkpoint: "latest" + val_every: 50 + max_val_examples_t2i: 2000 + +model: + vq_model: + type: "magvitv2" + vq_model_name: "showlab/magvitv2" + + mmada: + tokenizer_path: "GSAI-ML/LLaDA-8B-Instruct" + pretrained_model_path: "/data_storage/ty/MMaDA/mmada-training-stage3-llada-instruct-512-cot-uni/checkpoint-210000/unwrapped_model" + w_clip_vit: False + new_vocab_size: 134656 + llm_vocab_size: 126464 + codebook_size: 8192 + num_vq_tokens: 1024 + num_new_special_tokens: 0 + tie_word_embeddings: False + + gradient_checkpointing: True + +dataset: + gen_type: "t2i" + und_type: "captioning" + combined_loader_mode: "max_size_cycle" + params: + train_t2i_shards_path_or_url: [ "/data_storage/shared/datasets/JourneyDB/train/imgs/data/train/imgs/{000..199}.tgz", + "/data_storage/shared/datasets/laion-aesthetics-12m-filter/{00000..00999}.tar", + # "/data_storage/shared/datasets/text-to-image-2M/data_512_2M/data_{000000..000046}.tar" + ] + train_mmu_shards_path_or_url: [ "/data_storage/shared/datasets/multimodal_cot/ai2d/new_images.tar", + "/data_storage/shared/datasets/multimodal_cot/clevr/images.tar", + "/data_storage/shared/datasets/multimodal_cot/docvqa/images.tar", + "/data_storage/shared/datasets/multimodal_cot/geo/images.tar", + ] + train_lm_shards_path_or_url: "/data_storage/shared/datasets/falcon-refinedweb/data/data/*.parquet" + train_instruct_shards_path_or_url: "/data_storage/shared/datasets/stage4_instruct/*.parquet" + add_caption_prompt: True + external_caption_path: "/data_storage/shared/datasets/SAM-LLaVA-Captions10M" + external_journeydb_caption_path: "/data_storage/shared/datasets/journeydb_anno/train_journeydb_anno.json" + external_laion12m_caption_path: "/data_storage/ty/datasets/laion-aesthetics-12m-images-2" + external_cc12m_caption_path: "/data_storage/shared/datasets/cc12m/new_captions" + external_text_to_image_2M_512_caption_path: "/data_storage/shared/datasets/text-to-image-2M/data_512_2M_captions" + external_ai2d_caption_path: "/data_storage/shared/datasets/multimodal_cot/ai2d/new_metadata.csv" + external_clevr_caption_path: "/data_storage/shared/datasets/multimodal_cot/clevr/metadata.csv" + external_docvqa_caption_path: "/data_storage/shared/datasets/multimodal_cot/docvqa/metadata.csv" + external_geo_caption_path: "/data_storage/shared/datasets/multimodal_cot/geo/metadata.csv" + external_vqa_caption_path: "/data_storage/shared/datasets/LLaVA-Instruct-150K/llava_v1_5_mix665k.json" + external_clevr2_caption_path: "/data_storage/ty/datasets/Clevr_CoGenT_TrainA_70K_Complex/captions.json" + external_geo170k_caption_path: "/data_storage/ty/shared/datasets/Geo170K/Geo170K/all.json" + vqa_images_path: "/data_storage/shared/datasets/LLaVA-Instruct-150K-images" + clevr2_images_path: "/data_storage/ty/datasets/Clevr_CoGenT_TrainA_70K_Complex/images" + geo170k_images_path: "/data_storage/ty/shared/datasets/Geo170K/Geo170K/images" + validation_prompts_file: "validation_prompts/text2image_prompts.txt" + mmu_image_root: "/data_storage/ty/MMaDA/mmu_validation" + mmu_validation_prompts_file: "/data_storage/ty/MMaDA/mmu_validation/prompts_with_vqa.json" + lm_chat_validation_jsonl: "/data_storage/ty/MMaDA/lm_chat_validation/questions.jsonl" + shuffle_buffer_size: 1000 + num_workers: 16 + resolution: 512 + pin_memory: True + persistent_workers: True + + preprocessing: + max_seq_length: 512 # for text tokens in t2i & mmu + max_lm_text_length: 1536 # for text tokens in lm/lm_chat + resolution: 512 + center_crop: False + random_flip: False + +optimizer: + name: adamw + params: # default adamw params + learning_rate: 5e-5 + scale_lr: False # scale learning rate by total batch size + beta1: 0.9 + beta2: 0.999 + weight_decay: 0.01 + epsilon: 1e-8 + +lr_scheduler: + scheduler: "cosine" + params: + learning_rate: ${optimizer.params.learning_rate} + warmup_steps: 5000 + min_lr_scale: 0.1 + +training: + gradient_accumulation_steps: 4 # 4 + noise_type: "mask" + batch_size_t2i: 1 + batch_size_lm: 2 + batch_size_mmu: 1 + mixed_precision: "bf16" + enable_tf32: True + seed: 10086 + max_train_steps: 1000000 + overfit_one_batch: False + cond_dropout_prob: 0.1 + min_masking_rate: 0.0 + label_smoothing: 0.0 + max_grad_norm: 1 + guidance_scale: 5 + generation_timesteps: 20 + t2i_coeff: 0.05 + lm_coeff: 0.6 + mmu_coeff: 0.4 + cot_in_mmu_coeff: 3.5 + vqa_in_mmu_coeff: 5.5 + clevr2_in_mmu_coeff: 0.5 + geo170k_in_mmu_coeff: 0.5 + base_in_lm_coeff: 0.02 + instruct_in_lm_coeff: 0.98 + +validation: + quantative_prompts_file: "/data_storage/ty/MMaDA/validation_prompts/quantative.txt" + quantative_batch_size: 8 \ No newline at end of file diff --git a/MMaDA/configs/mmada_pretraining_t2s.yaml b/MMaDA/configs/mmada_pretraining_t2s.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3dd75b703f95e25aa8754bef3fda1962743cf9d5 --- /dev/null +++ b/MMaDA/configs/mmada_pretraining_t2s.yaml @@ -0,0 +1,96 @@ +wandb: + entity: null + # run_id: askkz9i2 + resume: 'auto' + +experiment: + project: "ommda-training-t2s" + name: "ommda-training-t2s-mmada" + output_dir: "ommda-training-t2s-mmada" + save_every: 5000 + eval_every: 20000 + generate_every: 5000 + num_validation_images: 20 + log_every: 1 + log_grad_norm_every: 100 + resume_from_checkpoint: "latest" + val_every: 50000 + max_val_examples_t2i: 2000 + +model: + vq_model: + type: "emova" + vq_model_name: "Emova-ollm/emova_speech_tokenizer_hf" + + mmada: + tokenizer_path: "GSAI-ML/LLaDA-8B-Instruct" + pretrained_model_path: "Gen-Verse/MMaDA-8B-Base" + w_clip_vit: False + new_vocab_size: 138752 + llm_vocab_size: 126464 + codebook_size: 8192 + speech_codebook_size: 4096 + # num_vq_tokens: 256 + # num_speech_vq_tokens: 250 + num_new_special_tokens: 3 + tie_word_embeddings: False + + gradient_checkpointing: True + +dataset: + params: + num_workers: 0 + resolution: 256 + pin_memory: True + persistent_workers: True + + preprocessing: + max_seq_length: 256 # for text tokens + resolution: 256 + center_crop: False + random_flip: False + + data: + # subset for gigaspeech: xs, xl + # subset for librispeech: train-clean-360, train-clean-100 + # subset for commonvoice: validated, invalidated + name: "gigaspeech" + subset: "xl" + split: "train" + +optimizer: + name: adamw + params: # default adamw params + learning_rate: 1e-4 + scale_lr: False # scale learning rate by total batch size + beta1: 0.9 + beta2: 0.999 + weight_decay: 0.01 + epsilon: 1e-8 + +lr_scheduler: + scheduler: "cosine" + params: + learning_rate: ${optimizer.params.learning_rate} + warmup_steps: 2500 + min_lr_scale: 0.1 + +training: + gradient_accumulation_steps: 4 + noise_type: "mask" + batch_size_s2t: 4 + mixed_precision: "bf16" + enable_tf32: True + seed: 10086 + max_train_steps: 50000 + overfit_one_batch: False + cond_dropout_prob: 0.1 + min_masking_rate: 0.0 + label_smoothing: 0.0 + max_grad_norm: 1 + guidance_scale: 5 + generation_timesteps: 50 + t2i_coeff: 1.0 + lm_coeff: 0.1 + mmu_coeff: 0.5 + validation_seed: 42 \ No newline at end of file diff --git a/MMaDA/configs/mmada_pretraining_v2s.yaml b/MMaDA/configs/mmada_pretraining_v2s.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6d51f8b584db10ad49e62edcbdbcc3c4fb534012 --- /dev/null +++ b/MMaDA/configs/mmada_pretraining_v2s.yaml @@ -0,0 +1,133 @@ +wandb: + entity: null +# run_id: askkz9i2 + resume: 'auto' + +experiment: + project: "omada-training-stage1" + name: "omada-training-stage1_ignore_SP" + output_dir: "ckpts/omada/omada-training-stage1_v2s_test" + max_train_examples_t2i: 40000000 + max_train_examples_mmu: 40000000 + save_every: 5000 + eval_every: 5000 + generate_every: 1000000000 + log_every: 1 + log_grad_norm_every: 100 + resume_from_checkpoint: "latest" + +model: + vq_model_image: + type: "magvitv2" + vq_model_name: "showlab/magvitv2" + ### Omada ############################################################### + vq_model_audio: + type: "emova" + vq_model_name: "Emova-ollm/emova_speech_tokenizer_hf" + omada: + tokenizer_path: "GSAI-ML/LLaDA-8B-Instruct" + pretrained_model_path: "Gen-Verse/MMaDA-8B-MixCoT" + # pretrained_model_path: "Gen-Verse/MMaDA-8B-Base" + w_clip_vit: False + new_vocab_size: 138752 + llm_vocab_size: 126464 + codebook_size: 8192 + num_vq_tokens: 256 + num_new_special_tokens: 5 # task token 3 + eoa / soa + tie_word_embeddings: False + ######################################################################### + + gradient_checkpointing: True + +dataset: + gen_type: "pass" + und_type: "pass" + combined_loader_mode: "max_size_cycle" + params: + train_t2i_shards_path_or_url: "/data_storage/shared/datasets/imagenet-1k/data/train" + train_mmu_shards_path_or_url: [ "/data_storage/shared/datasets/SA-1B/sa_{000000..000999}.tar", + "/data_storage/shared/datasets/cc12m/raw/raw/{0000..0999}.tar", + "/data_storage/shared/datasets/laion-aesthetics-12m/{00000..00999}.tar" + ] + train_lm_shards_path_or_url: "/data_storage/shared/datasets/falcon-refinedweb/data/data/*.parquet" + add_caption_prompt: True + external_caption_path: "/data_storage/shared/datasets/SAM-LLaVA-Captions10M" + external_journeydb_caption_path: "/data_storage/shared/datasets/journeydb_anno/train_journeydb_anno.json" + external_laion12m_caption_path: "/data_storage/shared/datasets/laion-aesthetic-12m-captions" + external_cc12m_caption_path: "/data_storage/shared/datasets/cc12m/captions" + validation_prompts_file: "validation_prompts/imagenet_prompts.txt" + mmu_image_root: "/data_storage/ty/MMaDA/mmu_validation" + ### Omada ############################################################### + video_root: "/home/work/AIDAS/data/video/panda70m/panda70m_training_2m" + # subset for gigaspeech: xs, xl + # subset for librispeech: train-clean-360, train-clean-100 + # subset for commonvoice: validated, invalidated + audio_data: + - name: "gigaspeech" + subset: "xl" + split: "train" + - name: "librispeech" + subset: "train-clean-360" + - name: "commonvoice" + subset: "validated" + ######################################################################### + shuffle_buffer_size: 1000 + num_workers: 8 + resolution: 256 + pin_memory: True + persistent_workers: True + + preprocessing: + max_seq_length: 128 # for text tokens + max_aud_length: 384 # for audio tokens + resolution: 128 + center_crop: False + random_flip: False + +optimizer: + name: adamw + params: # default adamw params + # learning_rate: 1e-4 + learning_rate: 0.000079 + scale_lr: False # scale learning rate by total batch size + beta1: 0.9 + beta2: 0.999 + weight_decay: 0.01 + epsilon: 1e-8 + +lr_scheduler: + scheduler: "cosine" + params: + learning_rate: ${optimizer.params.learning_rate} + warmup_steps: 0 + min_lr_scale: 0.1 + +training: + gradient_accumulation_steps: 1 + noise_type: "mask" + batch_size_t2i: 0 + batch_size_lm: 0 + batch_size_mmu: 0 + batch_size_v2t: 0 + batch_size_s2t: 0 + batch_size_t2s: 0 + batch_size_v2s: 1 + + mixed_precision: "bf16" + enable_tf32: True + seed: 10086 + max_train_steps: 630000 # 2epoch + max_train_epochs: NONE + overfit_one_batch: False + cond_dropout_prob: 0.1 + min_masking_rate: 0.0 + label_smoothing: 0.0 + max_grad_norm: 1 + guidance_scale: 1.5 + generation_timesteps: 16 + # t2i_coeff: 0.1 + # lm_coeff: 0.1 + # mmu_coeff: 0.1 + v2t_coeff: 0.2 + t2s_coeff: 1.0 + s2t_coeff: 0.2 \ No newline at end of file diff --git a/MMaDA/configs/mmada_pretraining_v2t.yaml b/MMaDA/configs/mmada_pretraining_v2t.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6b6db11daee2db5f6106d5bdfa1ee448707513db --- /dev/null +++ b/MMaDA/configs/mmada_pretraining_v2t.yaml @@ -0,0 +1,88 @@ +wandb: + entity: null +# run_id: askkz9i2 + resume: 'auto' + +experiment: + project: "mmada-training-v2t" + name: "mmada-training-stage3-llada-instruct-v2t" + output_dir: "mmada-training-stage3-llada-instruct-v2t-special-token-1e-5" + max_train_examples_t2i: 40000000 # + max_train_examples_mmu: 40000000 ddd# + save_every: 1000 + eval_every: 2500 + generate_every: 1000 + log_every: 10 + log_grad_norm_every: 100 + resume_from_checkpoint: "latest" + val_every: 50 + max_val_examples_t2i: 2000 + +model: + vq_model: + type: "magvitv2" + vq_model_name: "showlab/magvitv2" + + mmada: + tokenizer_path: "GSAI-ML/LLaDA-8B-Instruct" + pretrained_model_path: "Gen-Verse/MMaDA-8B-Base" + w_clip_vit: False + new_vocab_size: 134656 + llm_vocab_size: 126464 + codebook_size: 8192 + num_vq_tokens: 256 + num_new_special_tokens: 0 + tie_word_embeddings: False + + gradient_checkpointing: True + +dataset: + und_type: "captioning" + combined_loader_mode: "max_size_cycle" + + preprocessing: + max_seq_length: 128 # for text tokens 512 + resolution: 128 + center_crop: False + random_flip: False + + params: + num_workers: 32 + + + +optimizer: + name: adamw + params: # default adamw params + learning_rate: 1e-5 + scale_lr: False # scale learning rate by total batch size + beta1: 0.9 + beta2: 0.999 + weight_decay: 0.01 + epsilon: 1e-8 + +lr_scheduler: + scheduler: "cosine" + params: + learning_rate: ${optimizer.params.learning_rate} + warmup_steps: 5000 + min_lr_scale: 0.1 + +training: + gradient_accumulation_steps: 4 # 4 + noise_type: "mask" + batch_size_v2t: 4 + batch_size_mmu: 1 + mixed_precision: "bf16" + enable_tf32: True + seed: 10086 + max_train_steps: 1000000 + overfit_one_batch: False + cond_dropout_prob: 0.1 + min_masking_rate: 0.0 + label_smoothing: 0.0 + max_grad_norm: 1 + guidance_scale: 3 + generation_timesteps: 12 + mmu_coeff: 1.0 + validation_seed: 42 \ No newline at end of file diff --git a/MMaDA/configs/omada_instruction_tuning.yaml b/MMaDA/configs/omada_instruction_tuning.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b93abb0abfadb0c69e1768bb0dc339d312fb2cf3 --- /dev/null +++ b/MMaDA/configs/omada_instruction_tuning.yaml @@ -0,0 +1,200 @@ +wandb: + entity: null +# run_id: askkz9i2 + resume: 'auto' + +experiment: + project: "omada-instruction-tuning" + name: "omada-instruction-tuning" + output_dir: "ckpts/omada/omada-instruction-tuning-tv_sacle_0.7" + max_train_examples_t2i: 40000000 + max_train_examples_mmu: 40000000 + save_every: 5000 + eval_every: 10000 + generate_every: 1000000000 + log_every: 1 + log_grad_norm_every: 100 + resume_from_checkpoint: "latest" + +model: + vq_model_image: + type: "magvitv2" + vq_model_name: "showlab/magvitv2" + ### Omada ############################################################### + vq_model_audio: + type: "emova" + vq_model_name: "Emova-ollm/emova_speech_tokenizer_hf" + omada: + tokenizer_path: "GSAI-ML/LLaDA-8B-Instruct" + # pretrained_model_path: "Gen-Verse/MMaDA-8B-MixCoT" + pretrained_model_path: "//home/work/AIDAS/ckpts/merged_model/hf_common_merge_alpha_999_scale_0p7" + w_clip_vit: False + new_vocab_size: 138752 + llm_vocab_size: 126464 + codebook_size: 8192 + num_vq_tokens: 256 + num_new_special_tokens: 3 # v2s, s2s, i2i + tie_word_embeddings: False + ######################################################################### + + gradient_checkpointing: True + +dataset: + gen_type: "pass" + und_type: "pass" + combined_loader_mode: "max_size_cycle" + params: + train_t2i_shards_path_or_url: "/data_storage/shared/datasets/imagenet-1k/data/train" + train_mmu_shards_path_or_url: [ "/data_storage/shared/datasets/SA-1B/sa_{000000..000999}.tar", + "/data_storage/shared/datasets/cc12m/raw/raw/{0000..0999}.tar", + "/data_storage/shared/datasets/laion-aesthetics-12m/{00000..00999}.tar" + ] + train_lm_shards_path_or_url: "/data_storage/shared/datasets/falcon-refinedweb/data/data/*.parquet" + add_caption_prompt: True + external_caption_path: "/data_storage/shared/datasets/SAM-LLaVA-Captions10M" + external_journeydb_caption_path: "/data_storage/shared/datasets/journeydb_anno/train_journeydb_anno.json" + external_laion12m_caption_path: "/data_storage/shared/datasets/laion-aesthetic-12m-captions" + external_cc12m_caption_path: "/data_storage/shared/datasets/cc12m/captions" + validation_prompts_file: "validation_prompts/imagenet_prompts.txt" + mmu_image_root: "/data_storage/ty/MMaDA/mmu_validation" + ### Omada ############################################################### + video_root: "/home/work/AIDAS/data/video/panda70m/panda70m_training_2m" + video_speech_dataset: + sample_mode: "exclusive" + use_precomputed_tokens: true + precomputed_tokens_root: "/home/work/AIDAS/cache/speech_tokens" + llavavid_path: "/home/work/AIDAS/data/video/LLaVA-Video-178K" + llavavid_local_files_only: true + llavavid_skip_configs: + - "llava_hound" + - "0_30_s_activitynetqa" + - "30_60_s_activitynetqa" + - "1_2_m_activitynetqa" + - "2_3_m_activitynetqa" + - "0_30_s_activitynet" + - "30_60_s_activitynet" + - "1_2_m_activitynet" + - "2_3_m_activitynet" + llavavid_skip_video_patterns: + - "activitynet" + # video_dataset_name: "openvid1m" + hqedit_split: "train" + t2i_dataset: "text2image2m+openimage_i2i+hqedit" + t2i_split: "train" + t2i_dataset_name: "jackyhate/text-to-image-2M" + t2i_local_files_only: true + openimage_i2i: + sft_jsonl: "/home/work/AIDAS/data/openimage_source_images/sft_with_local_source_image_path.jsonl" + pref_jsonl: "/home/work/AIDAS/data/openimage_source_images/pref_with_local_source_image_path.jsonl" + multi_turn_jsonl: "/home/work/AIDAS/data/openimage_source_images/multi-turn_with_local_source_image_path.jsonl" + image_root: "/home/work/AIDAS/data/nano_edited_images" + prefer_summarized_text: true + pref_positive_only: true + skip_missing: true + max_samples_per_source: null + max_total_samples: null + seed: 42 + hf_instruction_lm: + split: "train" + max_samples_per_source: 1000000 + max_total_samples: 20000000 + seed: 42 + speech2speech: + - name: "instructs2s" + use_precomputed_tokens: false + precomputed_tokens_root: "/home/work/AIDAS/cache/instructs2s_tokens" + mmu_interleaved: + local_data_root: /home/work/AIDAS/data/TIGER-Lab/Mantis-Instruct + local_files_only: true + # subset for gigaspeech: xs, xl + # subset for librispeech: train-clean-360, train-clean-100 + # subset for commonvoice: validated, invalidated + audio_data: + # - name: "gigaspeech" + # subset: "xl" + # split: "train" + - name: "librispeech" + subset: "train-clean-360" + use_precomputed_tokens: true + precomputed_tokens_root: "/home/work/AIDAS/cache/librispeech_tokens" + # - name: "commonvoice" + # subset: "validated" + ######################################################################### + shuffle_buffer_size: 1000 + num_workers: 0 + resolution: 256 + # resolution: 16 + pin_memory: False + persistent_workers: False + dataloader_timeout: 0 + + + speech_token_cache: + enable: true + root: "cache/speech_tokens" + max_items_in_memory: 4096 + + preprocessing: + max_seq_length: 128 # for text tokens + max_aud_length: 384 # for audio tokens + max_aud_length_short: 256 # for short audio tokens + resolution: 128 # for video tokens + # max_seq_length: 16 # for text tokens + # max_aud_length: 16 # for audio tokens + # resolution: 16 # for video tokens + center_crop: False + random_flip: False + +optimizer: + name: adamw + params: # default adamw params + learning_rate: 5e-5 + # learning_rate: 0.00004859840219369731 + scale_lr: False # scale learning rate by total batch size + beta1: 0.9 + beta2: 0.999 + weight_decay: 0.01 + epsilon: 1e-8 + +lr_scheduler: + scheduler: "cosine" + params: + learning_rate: ${optimizer.params.learning_rate} + # warmup_steps: 1000 + warmup_steps: 0 + min_lr_scale: 0.1 + +training: + gradient_accumulation_steps: 1 + noise_type: "mask" + batch_size_t2i: 1 + batch_size_lm: 1 + batch_size_mmu: 1 + batch_size_v2t: 1 + batch_size_v2s: 1 + batch_size_s2t: 2 + batch_size_t2s: 2 + batch_size_s2s: 2 + + mixed_precision: "bf16" + enable_tf32: True + seed: 10086 + max_train_steps: 500000 + max_train_epochs: NONE + overfit_one_batch: False + cond_dropout_prob: 0.1 + min_masking_rate: 0.0 + label_smoothing: 0.0 + max_grad_norm: 1 + guidance_scale: 3.5 + generation_timesteps: 42 + + t2i_coeff: 2.5 + i2i_coeff: 2.5 + lm_coeff: 2.5 + mmu_coeff: 0.1 + v2t_coeff: 0.2 + v2s_coeff: 2.0 + t2s_coeff: 2.5 + s2t_coeff: 0.5 + s2s_coeff: 3.0 diff --git a/MMaDA/configs/omada_pretraining_stage1-2.yaml b/MMaDA/configs/omada_pretraining_stage1-2.yaml new file mode 100644 index 0000000000000000000000000000000000000000..36e21ae2a9cc4dc5a053d2ba111afcad39b11c72 --- /dev/null +++ b/MMaDA/configs/omada_pretraining_stage1-2.yaml @@ -0,0 +1,131 @@ +wandb: + entity: null +# run_id: askkz9i2 + resume: 'auto' + +experiment: + project: "omada-training-stage1" + name: "omada-training-stage1" + output_dir: "ckpts/omada/omada-training-stage1_2nd" + max_train_examples_t2i: 40000000 + max_train_examples_mmu: 40000000 + save_every: 5000 + eval_every: 5000 + generate_every: 1000000000 + log_every: 1 + log_grad_norm_every: 100 + resume_from_checkpoint: "latest" + +model: + vq_model_image: + type: "magvitv2" + vq_model_name: "showlab/magvitv2" + ### Omada ############################################################### + vq_model_audio: + type: "emova" + vq_model_name: "Emova-ollm/emova_speech_tokenizer_hf" + omada: + tokenizer_path: "GSAI-ML/LLaDA-8B-Instruct" + pretrained_model_path: "Gen-Verse/MMaDA-8B-MixCoT" + # pretrained_model_path: "Gen-Verse/MMaDA-8B-Base" + w_clip_vit: False + new_vocab_size: 138752 + llm_vocab_size: 126464 + codebook_size: 8192 + num_vq_tokens: 256 + num_new_special_tokens: 5 # task token 3 + eoa / soa + tie_word_embeddings: False + ######################################################################### + + gradient_checkpointing: True + +dataset: + gen_type: "pass" + und_type: "pass" + combined_loader_mode: "max_size_cycle" + params: + train_t2i_shards_path_or_url: "/data_storage/shared/datasets/imagenet-1k/data/train" + train_mmu_shards_path_or_url: [ "/data_storage/shared/datasets/SA-1B/sa_{000000..000999}.tar", + "/data_storage/shared/datasets/cc12m/raw/raw/{0000..0999}.tar", + "/data_storage/shared/datasets/laion-aesthetics-12m/{00000..00999}.tar" + ] + train_lm_shards_path_or_url: "/data_storage/shared/datasets/falcon-refinedweb/data/data/*.parquet" + add_caption_prompt: True + external_caption_path: "/data_storage/shared/datasets/SAM-LLaVA-Captions10M" + external_journeydb_caption_path: "/data_storage/shared/datasets/journeydb_anno/train_journeydb_anno.json" + external_laion12m_caption_path: "/data_storage/shared/datasets/laion-aesthetic-12m-captions" + external_cc12m_caption_path: "/data_storage/shared/datasets/cc12m/captions" + validation_prompts_file: "validation_prompts/imagenet_prompts.txt" + mmu_image_root: "/data_storage/ty/MMaDA/mmu_validation" + ### Omada ############################################################### + video_root: "/home/work/AIDAS/data/video/panda70m/panda70m_training_2m" + # subset for gigaspeech: xs, xl + # subset for librispeech: train-clean-360, train-clean-100 + # subset for commonvoice: validated, invalidated + audio_data: + - name: "gigaspeech" + subset: "xl" + split: "train" + - name: "librispeech" + subset: "train-clean-360" + - name: "commonvoice" + subset: "validated" + ######################################################################### + shuffle_buffer_size: 1000 + num_workers: 8 + resolution: 256 + pin_memory: True + persistent_workers: True + + preprocessing: + max_seq_length: 128 # for text tokens + max_aud_length: 256 # for audio tokens + resolution: 128 + center_crop: False + random_flip: False + +optimizer: + name: adamw + params: # default adamw params + learning_rate: 5e-5 + scale_lr: False # scale learning rate by total batch size + beta1: 0.9 + beta2: 0.999 + weight_decay: 0.01 + epsilon: 1e-8 + +lr_scheduler: + scheduler: "cosine" + params: + learning_rate: ${optimizer.params.learning_rate} + warmup_steps: 0 + min_lr_scale: 0.1 + +training: + gradient_accumulation_steps: 1 + noise_type: "mask" + batch_size_t2i: 0 + batch_size_lm: 0 + batch_size_mmu: 0 + batch_size_v2t: 1 + batch_size_s2t: 1 + batch_size_t2s: 5 + + mixed_precision: "bf16" + enable_tf32: True + seed: 10086 + max_train_steps: 315000 # 2epoch + max_train_epochs: NONE + overfit_one_batch: False + cond_dropout_prob: 0.1 + min_masking_rate: 0.0 + label_smoothing: 0.0 + max_grad_norm: 1 + guidance_scale: 0.0 + generation_timesteps: 64 + # t2i_coeff: 0.1 + # lm_coeff: 0.1 + # mmu_coeff: 0.1 + v2t_coeff: 0.1 + t2s_coeff: 1.0 + s2t_coeff: 0.1 \ No newline at end of file diff --git a/MMaDA/configs/omada_pretraining_stage1-3.yaml b/MMaDA/configs/omada_pretraining_stage1-3.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3e767a9bcdc849c96631af43e70400fefdf07873 --- /dev/null +++ b/MMaDA/configs/omada_pretraining_stage1-3.yaml @@ -0,0 +1,132 @@ +wandb: + entity: null +# run_id: askkz9i2 + resume: 'auto' + +experiment: + project: "omada-training-stage1" + name: "omada-training-stage1_ignore_SP" + output_dir: "ckpts/omada/omada-training-stage1_7th" + max_train_examples_t2i: 40000000 + max_train_examples_mmu: 40000000 + save_every: 5000 + eval_every: 5000 + generate_every: 1000000000 + log_every: 1 + log_grad_norm_every: 100 + resume_from_checkpoint: "latest" + +model: + vq_model_image: + type: "magvitv2" + vq_model_name: "showlab/magvitv2" + ### Omada ############################################################### + vq_model_audio: + type: "emova" + vq_model_name: "Emova-ollm/emova_speech_tokenizer_hf" + omada: + tokenizer_path: "GSAI-ML/LLaDA-8B-Instruct" + pretrained_model_path: "Gen-Verse/MMaDA-8B-MixCoT" + # pretrained_model_path: "Gen-Verse/MMaDA-8B-Base" + w_clip_vit: False + new_vocab_size: 138752 + llm_vocab_size: 126464 + codebook_size: 8192 + num_vq_tokens: 256 + num_new_special_tokens: 5 # task token 3 + eoa / soa + tie_word_embeddings: False + ######################################################################### + + gradient_checkpointing: True + +dataset: + gen_type: "pass" + und_type: "pass" + combined_loader_mode: "max_size_cycle" + params: + train_t2i_shards_path_or_url: "/data_storage/shared/datasets/imagenet-1k/data/train" + train_mmu_shards_path_or_url: [ "/data_storage/shared/datasets/SA-1B/sa_{000000..000999}.tar", + "/data_storage/shared/datasets/cc12m/raw/raw/{0000..0999}.tar", + "/data_storage/shared/datasets/laion-aesthetics-12m/{00000..00999}.tar" + ] + train_lm_shards_path_or_url: "/data_storage/shared/datasets/falcon-refinedweb/data/data/*.parquet" + add_caption_prompt: True + external_caption_path: "/data_storage/shared/datasets/SAM-LLaVA-Captions10M" + external_journeydb_caption_path: "/data_storage/shared/datasets/journeydb_anno/train_journeydb_anno.json" + external_laion12m_caption_path: "/data_storage/shared/datasets/laion-aesthetic-12m-captions" + external_cc12m_caption_path: "/data_storage/shared/datasets/cc12m/captions" + validation_prompts_file: "validation_prompts/imagenet_prompts.txt" + mmu_image_root: "/data_storage/ty/MMaDA/mmu_validation" + ### Omada ############################################################### + video_root: "/home/work/AIDAS/data/video/panda70m/panda70m_training_2m" + # subset for gigaspeech: xs, xl + # subset for librispeech: train-clean-360, train-clean-100 + # subset for commonvoice: validated, invalidated + audio_data: + - name: "gigaspeech" + subset: "xl" + split: "train" + - name: "librispeech" + subset: "train-clean-360" + - name: "commonvoice" + subset: "validated" + ######################################################################### + shuffle_buffer_size: 1000 + num_workers: 8 + resolution: 256 + pin_memory: True + persistent_workers: True + + preprocessing: + max_seq_length: 128 # for text tokens + max_aud_length: 384 # for audio tokens + resolution: 128 + center_crop: False + random_flip: False + +optimizer: + name: adamw + params: # default adamw params + # learning_rate: 1e-4 + learning_rate: 0.000079 + scale_lr: False # scale learning rate by total batch size + beta1: 0.9 + beta2: 0.999 + weight_decay: 0.01 + epsilon: 1e-8 + +lr_scheduler: + scheduler: "cosine" + params: + learning_rate: ${optimizer.params.learning_rate} + warmup_steps: 0 + min_lr_scale: 0.1 + +training: + gradient_accumulation_steps: 1 + noise_type: "mask" + batch_size_t2i: 0 + batch_size_lm: 0 + batch_size_mmu: 0 + batch_size_v2t: 1 + batch_size_s2t: 1 + batch_size_t2s: 5 + + mixed_precision: "bf16" + enable_tf32: True + seed: 10086 + max_train_steps: 630000 # 2epoch + max_train_epochs: NONE + overfit_one_batch: False + cond_dropout_prob: 0.1 + min_masking_rate: 0.0 + label_smoothing: 0.0 + max_grad_norm: 1 + guidance_scale: 1.5 + generation_timesteps: 16 + # t2i_coeff: 0.1 + # lm_coeff: 0.1 + # mmu_coeff: 0.1 + v2t_coeff: 0.2 + t2s_coeff: 1.0 + s2t_coeff: 0.2 \ No newline at end of file diff --git a/MMaDA/configs/omada_pretraining_stage1-4.yaml b/MMaDA/configs/omada_pretraining_stage1-4.yaml new file mode 100644 index 0000000000000000000000000000000000000000..93b744e84a55e3efc81aea11a66899aea732cacb --- /dev/null +++ b/MMaDA/configs/omada_pretraining_stage1-4.yaml @@ -0,0 +1,132 @@ +wandb: + entity: null +# run_id: askkz9i2 + resume: 'auto' + +experiment: + project: "omada-training-stage1" + name: "omada-training-stage1_ignore_SP" + output_dir: "ckpts/omada/omada-training-stage1_5th" + max_train_examples_t2i: 40000000 + max_train_examples_mmu: 40000000 + save_every: 5000 + eval_every: 5000 + generate_every: 1000000000 + log_every: 1 + log_grad_norm_every: 100 + resume_from_checkpoint: "latest" + +model: + vq_model_image: + type: "magvitv2" + vq_model_name: "showlab/magvitv2" + ### Omada ############################################################### + vq_model_audio: + type: "emova" + vq_model_name: "Emova-ollm/emova_speech_tokenizer_hf" + omada: + tokenizer_path: "GSAI-ML/LLaDA-8B-Instruct" + pretrained_model_path: "Gen-Verse/MMaDA-8B-MixCoT" + # pretrained_model_path: "Gen-Verse/MMaDA-8B-Base" + w_clip_vit: False + new_vocab_size: 138752 + llm_vocab_size: 126464 + codebook_size: 8192 + num_vq_tokens: 256 + num_new_special_tokens: 5 # task token 3 + eoa / soa + tie_word_embeddings: False + ######################################################################### + + gradient_checkpointing: True + +dataset: + gen_type: "pass" + und_type: "pass" + combined_loader_mode: "max_size_cycle" + params: + train_t2i_shards_path_or_url: "/data_storage/shared/datasets/imagenet-1k/data/train" + train_mmu_shards_path_or_url: [ "/data_storage/shared/datasets/SA-1B/sa_{000000..000999}.tar", + "/data_storage/shared/datasets/cc12m/raw/raw/{0000..0999}.tar", + "/data_storage/shared/datasets/laion-aesthetics-12m/{00000..00999}.tar" + ] + train_lm_shards_path_or_url: "/data_storage/shared/datasets/falcon-refinedweb/data/data/*.parquet" + add_caption_prompt: True + external_caption_path: "/data_storage/shared/datasets/SAM-LLaVA-Captions10M" + external_journeydb_caption_path: "/data_storage/shared/datasets/journeydb_anno/train_journeydb_anno.json" + external_laion12m_caption_path: "/data_storage/shared/datasets/laion-aesthetic-12m-captions" + external_cc12m_caption_path: "/data_storage/shared/datasets/cc12m/captions" + validation_prompts_file: "validation_prompts/imagenet_prompts.txt" + mmu_image_root: "/data_storage/ty/MMaDA/mmu_validation" + ### Omada ############################################################### + video_root: "/home/work/AIDAS/data/video/panda70m/panda70m_training_2m" + # subset for gigaspeech: xs, xl + # subset for librispeech: train-clean-360, train-clean-100 + # subset for commonvoice: validated, invalidated + audio_data: + - name: "gigaspeech" + subset: "xl" + split: "train" + - name: "librispeech" + subset: "train-clean-360" + - name: "commonvoice" + subset: "validated" + ######################################################################### + shuffle_buffer_size: 1000 + num_workers: 4 + resolution: 256 + pin_memory: True + persistent_workers: True + + preprocessing: + max_seq_length: 128 # for text tokens + max_aud_length: 256 # for audio tokens + resolution: 128 + center_crop: False + random_flip: False + +optimizer: + name: adamw + params: # default adamw params + # learning_rate: 5e-6 + learning_rate: 0.00000483 + scale_lr: False # scale learning rate by total batch size + beta1: 0.9 + beta2: 0.999 + weight_decay: 0.01 + epsilon: 1e-8 + +lr_scheduler: + scheduler: "cosine" + params: + learning_rate: ${optimizer.params.learning_rate} + warmup_steps: 0 + min_lr_scale: 0.1 + +training: + gradient_accumulation_steps: 1 + noise_type: "mask" + batch_size_t2i: 0 + batch_size_lm: 0 + batch_size_mmu: 0 + batch_size_v2t: 1 + batch_size_s2t: 1 + batch_size_t2s: 5 + + mixed_precision: "bf16" + enable_tf32: True + seed: 10086 + max_train_steps: 630000 # 2epoch + max_train_epochs: NONE + overfit_one_batch: False + cond_dropout_prob: 0.1 + min_masking_rate: 0.0 + label_smoothing: 0.0 + max_grad_norm: 1 + guidance_scale: 1.5 + generation_timesteps: 16 + # t2i_coeff: 0.1 + # lm_coeff: 0.1 + # mmu_coeff: 0.1 + v2t_coeff: 0.2 + t2s_coeff: 1.0 + s2t_coeff: 0.2 \ No newline at end of file diff --git a/MMaDA/configs/omada_pretraining_stage1.yaml b/MMaDA/configs/omada_pretraining_stage1.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d9d74ad61f31827f8264c14ec9db0a2a71831b38 --- /dev/null +++ b/MMaDA/configs/omada_pretraining_stage1.yaml @@ -0,0 +1,131 @@ +wandb: + entity: null +# run_id: askkz9i2 + resume: 'auto' + +experiment: + project: "omada-training-stage1" + name: "omada-training-stage1" + output_dir: "ckpts/omada/omada-training-stage1" + max_train_examples_t2i: 40000000 + max_train_examples_mmu: 40000000 + save_every: 5000 + eval_every: 10000000000 + generate_every: 1000000000 + log_every: 1 + log_grad_norm_every: 100 + resume_from_checkpoint: "latest" + +model: + vq_model_image: + type: "magvitv2" + vq_model_name: "showlab/magvitv2" + ### Omada ############################################################### + vq_model_audio: + type: "emova" + vq_model_name: "Emova-ollm/emova_speech_tokenizer_hf" + omada: + tokenizer_path: "GSAI-ML/LLaDA-8B-Instruct" + pretrained_model_path: "Gen-Verse/MMaDA-8B-MixCoT" + # pretrained_model_path: "Gen-Verse/MMaDA-8B-Base" + w_clip_vit: False + new_vocab_size: 138752 + llm_vocab_size: 126464 + codebook_size: 8192 + num_vq_tokens: 256 + num_new_special_tokens: 5 # task token 3 + eoa / soa + tie_word_embeddings: False + ######################################################################### + + gradient_checkpointing: True + +dataset: + gen_type: "pass" + und_type: "pass" + combined_loader_mode: "max_size_cycle" + params: + train_t2i_shards_path_or_url: "/data_storage/shared/datasets/imagenet-1k/data/train" + train_mmu_shards_path_or_url: [ "/data_storage/shared/datasets/SA-1B/sa_{000000..000999}.tar", + "/data_storage/shared/datasets/cc12m/raw/raw/{0000..0999}.tar", + "/data_storage/shared/datasets/laion-aesthetics-12m/{00000..00999}.tar" + ] + train_lm_shards_path_or_url: "/data_storage/shared/datasets/falcon-refinedweb/data/data/*.parquet" + add_caption_prompt: True + external_caption_path: "/data_storage/shared/datasets/SAM-LLaVA-Captions10M" + external_journeydb_caption_path: "/data_storage/shared/datasets/journeydb_anno/train_journeydb_anno.json" + external_laion12m_caption_path: "/data_storage/shared/datasets/laion-aesthetic-12m-captions" + external_cc12m_caption_path: "/data_storage/shared/datasets/cc12m/captions" + validation_prompts_file: "validation_prompts/imagenet_prompts.txt" + mmu_image_root: "/data_storage/ty/MMaDA/mmu_validation" + ### Omada ############################################################### + video_root: "/home/work/AIDAS/data/video/panda70m/panda70m_training_2m" + # subset for gigaspeech: xs, xl + # subset for librispeech: train-clean-360, train-clean-100 + # subset for commonvoice: validated, invalidated + audio_data: + - name: "gigaspeech" + subset: "xl" + split: "train" + - name: "librispeech" + subset: "train-clean-360" + - name: "commonvoice" + subset: "validated" + ######################################################################### + shuffle_buffer_size: 1000 + num_workers: 8 + resolution: 256 + pin_memory: True + persistent_workers: True + + preprocessing: + max_seq_length: 128 # for text tokens + max_aud_length: 256 # for audio tokens + resolution: 128 + center_crop: False + random_flip: False + +optimizer: + name: adamw + params: # default adamw params + learning_rate: 1e-5 + scale_lr: False # scale learning rate by total batch size + beta1: 0.9 + beta2: 0.999 + weight_decay: 0.01 + epsilon: 1e-8 + +lr_scheduler: + scheduler: "cosine" + params: + learning_rate: ${optimizer.params.learning_rate} + warmup_steps: 3000 + min_lr_scale: 0.1 + +training: + gradient_accumulation_steps: 1 + noise_type: "mask" + batch_size_t2i: 0 + batch_size_lm: 0 + batch_size_mmu: 0 + batch_size_v2t: 2 + batch_size_s2t: 2 + batch_size_t2s: 3 + + mixed_precision: "bf16" + enable_tf32: True + seed: 10086 + max_train_steps: 200000 + max_train_epochs: 1 + overfit_one_batch: False + cond_dropout_prob: 0.1 + min_masking_rate: 0.0 + label_smoothing: 0.0 + max_grad_norm: 1 + guidance_scale: 1.5 + generation_timesteps: 12 + # t2i_coeff: 0.1 + # lm_coeff: 0.1 + # mmu_coeff: 0.1 + v2t_coeff: 1.0 + t2s_coeff: 1.0 + s2t_coeff: 1.0 \ No newline at end of file diff --git a/MMaDA/configs/omada_pretraining_v2t_inst.yaml b/MMaDA/configs/omada_pretraining_v2t_inst.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5057df20ed4e0e2624fcfe0e018e9549edad0e27 --- /dev/null +++ b/MMaDA/configs/omada_pretraining_v2t_inst.yaml @@ -0,0 +1,132 @@ +wandb: + entity: null +# run_id: askkz9i2 + resume: 'auto' + +experiment: + project: "omada-training-v2t_inst" + name: "omada-training-v2t_inst" + output_dir: "ckpts/omada/omada-training-v2t_inst" + max_train_examples_t2i: 40000000 + max_train_examples_mmu: 40000000 + save_every: 5000 + eval_every: 5000 + generate_every: 1000000000 + log_every: 1 + log_grad_norm_every: 100 + resume_from_checkpoint: "latest" + +model: + vq_model_image: + type: "magvitv2" + vq_model_name: "showlab/magvitv2" + ### Omada ############################################################### + vq_model_audio: + type: "emova" + vq_model_name: "Emova-ollm/emova_speech_tokenizer_hf" + omada: + tokenizer_path: "GSAI-ML/LLaDA-8B-Instruct" + pretrained_model_path: "Gen-Verse/MMaDA-8B-MixCoT" + # pretrained_model_path: "Gen-Verse/MMaDA-8B-Base" + w_clip_vit: False + new_vocab_size: 138752 + llm_vocab_size: 126464 + codebook_size: 8192 + num_vq_tokens: 256 + num_new_special_tokens: 5 # task token 3 + eoa / soa + tie_word_embeddings: False + ######################################################################### + + gradient_checkpointing: True + +dataset: + gen_type: "pass" + und_type: "pass" + combined_loader_mode: "max_size_cycle" + params: + train_t2i_shards_path_or_url: "/data_storage/shared/datasets/imagenet-1k/data/train" + train_mmu_shards_path_or_url: [ "/data_storage/shared/datasets/SA-1B/sa_{000000..000999}.tar", + "/data_storage/shared/datasets/cc12m/raw/raw/{0000..0999}.tar", + "/data_storage/shared/datasets/laion-aesthetics-12m/{00000..00999}.tar" + ] + train_lm_shards_path_or_url: "/data_storage/shared/datasets/falcon-refinedweb/data/data/*.parquet" + add_caption_prompt: True + external_caption_path: "/data_storage/shared/datasets/SAM-LLaVA-Captions10M" + external_journeydb_caption_path: "/data_storage/shared/datasets/journeydb_anno/train_journeydb_anno.json" + external_laion12m_caption_path: "/data_storage/shared/datasets/laion-aesthetic-12m-captions" + external_cc12m_caption_path: "/data_storage/shared/datasets/cc12m/captions" + validation_prompts_file: "validation_prompts/imagenet_prompts.txt" + mmu_image_root: "/data_storage/ty/MMaDA/mmu_validation" + ### Omada ############################################################### + video_root: "/home/work/AIDAS/data/video/panda70m/panda70m_training_2m" + # subset for gigaspeech: xs, xl + # subset for librispeech: train-clean-360, train-clean-100 + # subset for commonvoice: validated, invalidated + audio_data: + - name: "gigaspeech" + subset: "xl" + split: "train" + - name: "librispeech" + subset: "train-clean-360" + - name: "commonvoice" + subset: "validated" + ######################################################################### + shuffle_buffer_size: 1000 + num_workers: 8 + resolution: 256 + pin_memory: True + persistent_workers: True + + preprocessing: + max_seq_length: 128 # for text tokens + max_aud_length: 384 # for audio tokens + resolution: 128 + center_crop: False + random_flip: False + +optimizer: + name: adamw + params: # default adamw params + # learning_rate: 1e-4 + learning_rate: 0.000079 + scale_lr: False # scale learning rate by total batch size + beta1: 0.9 + beta2: 0.999 + weight_decay: 0.01 + epsilon: 1e-8 + +lr_scheduler: + scheduler: "cosine" + params: + learning_rate: ${optimizer.params.learning_rate} + warmup_steps: 0 + min_lr_scale: 0.1 + +training: + gradient_accumulation_steps: 1 + noise_type: "mask" + batch_size_t2i: 0 + batch_size_lm: 0 + batch_size_mmu: 0 + batch_size_v2t: 1 + batch_size_s2t: 1 + batch_size_t2s: 5 + + mixed_precision: "bf16" + enable_tf32: True + seed: 10086 + max_train_steps: 630000 # 2epoch + max_train_epochs: NONE + overfit_one_batch: False + cond_dropout_prob: 0.1 + min_masking_rate: 0.0 + label_smoothing: 0.0 + max_grad_norm: 1 + guidance_scale: 1.5 + generation_timesteps: 16 + # t2i_coeff: 0.1 + # lm_coeff: 0.1 + # mmu_coeff: 0.1 + v2t_coeff: 0.2 + t2s_coeff: 1.0 + s2t_coeff: 0.2 \ No newline at end of file diff --git a/MMaDA/debug_speech_dataloader.py b/MMaDA/debug_speech_dataloader.py new file mode 100644 index 0000000000000000000000000000000000000000..a63d794994d516774baf7aee042613d88b56b78a --- /dev/null +++ b/MMaDA/debug_speech_dataloader.py @@ -0,0 +1,222 @@ +#!/usr/bin/env python3 +"""Utility to reproduce and debug the speech DataLoader used in training. + +This script pulls the speech dataset configuration from the Omada +instruction-tuning config, instantiates the same `MixedSpeechTextDataset`, and +iterates a configurable number of batches while measuring how long each fetch +takes. Use it to spot slow or stuck samples without launching the full training +job. + +Typical usage:: + + python AIDAS/MMaDA/script/debug_speech_dataloader.py \ + --config AIDAS/MMaDA/configs/omada_instruction_tuning.yaml \ + --flow s2t --max-batches 5 --num-workers 1 --timeout 0 + +Pass `--inspect-items` for a direct `dataset[idx]` sweep when a specific sample +seems suspicious. +""" + +from __future__ import annotations + +import argparse +import itertools +import logging +import sys +import time +from pathlib import Path +from typing import Any, Iterable, List + +from omegaconf import OmegaConf +from torch.utils.data import DataLoader + +from MMaDA.training.data import MixedSpeechTextDataset + + +def _collate_fn_audio(batch: List[dict[str, Any]]) -> dict[str, List[Any]]: + """Match the collate function used in training for speech flows.""" + + return { + "audio_path": [item["audio_path"] for item in batch], + "text": [item["text"] for item in batch], + "audio_tokens": [item.get("audio_tokens") for item in batch], + } + + +def _as_list_of_dicts(cfg_fragment: Any) -> List[dict[str, Any]]: + container = OmegaConf.to_container(cfg_fragment, resolve=True) + if not isinstance(container, Iterable): # pragma: no cover - sanity guard + raise TypeError("audio_data config must be a list of dataset dicts") + return list(container) # type: ignore[arg-type] + + +def _build_dataset(cfg) -> MixedSpeechTextDataset: + dataset_cfg = cfg.dataset.params + audio_data_cfg = _as_list_of_dicts(dataset_cfg.audio_data) + return MixedSpeechTextDataset(audio_data_cfg) + + +def _log_batch_summary(idx: int, batch: dict[str, List[Any]], elapsed: float) -> None: + audio_paths = batch.get("audio_path", []) + sample = audio_paths[0] if audio_paths else "" + logging.info( + "batch=%d size=%d elapsed=%.2fs sample=%s", + idx, + len(audio_paths), + elapsed, + sample, + ) + + +def _inspect_items(dataset: MixedSpeechTextDataset, max_items: int) -> None: + logging.info("Inspecting individual dataset items (max=%d)", max_items) + for idx in itertools.islice(range(len(dataset)), max_items): + tick = time.perf_counter() + try: + item = dataset[idx] + except Exception as exc: # pragma: no cover - diagnostic path + logging.error("idx=%d failed: %s", idx, exc) + continue + elapsed = time.perf_counter() - tick + logging.info( + "idx=%d elapsed=%.2fs path=%s text_len=%d tokens=%s", + idx, + elapsed, + item.get("audio_path"), + len(item.get("text", "")), + "cached" if item.get("audio_tokens") is not None else "None", + ) + + +def parse_args(argv: List[str]) -> argparse.Namespace: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--config", + type=Path, + default=Path("AIDAS/MMaDA/configs/omada_instruction_tuning.yaml"), + help="Path to the training config YAML", + ) + parser.add_argument( + "--flow", + choices=["s2t", "t2s"], + default="s2t", + help="Which speech flow's batch size defaults to use", + ) + parser.add_argument( + "--batch-size", + type=int, + default=None, + help="Override batch size (defaults to config.training.batch_size_)", + ) + parser.add_argument( + "--num-workers", + type=int, + default=None, + help="Override DataLoader workers (defaults to config.dataset.params.num_workers)", + ) + parser.add_argument( + "--persistent-workers", + action="store_true", + help="Enable persistent workers regardless of config", + ) + parser.add_argument( + "--timeout", + type=float, + default=None, + help="DataLoader timeout in seconds (defaults to config.dataset.params.dataloader_timeout)", + ) + parser.add_argument( + "--max-batches", + type=int, + default=10, + help="Number of batches to iterate (0 means run through the entire dataset)", + ) + parser.add_argument( + "--inspect-items", + type=int, + default=0, + help="If >0, bypass the DataLoader and inspect this many individual dataset items first", + ) + parser.add_argument( + "--prefetch-factor", + type=int, + default=None, + help="Optional override for DataLoader prefetch_factor", + ) + parser.add_argument( + "--log-level", + default="INFO", + help="Logging level", + ) + return parser.parse_args(argv) + + +def main(argv: List[str]) -> int: + args = parse_args(argv) + logging.basicConfig( + level=getattr(logging, args.log_level.upper(), logging.INFO), + format="%(asctime)s | %(levelname)s | %(message)s", + ) + + cfg = OmegaConf.load(args.config) + dataset = _build_dataset(cfg) + + if args.inspect_items: + _inspect_items(dataset, args.inspect_items) + + dataset_params = cfg.dataset.params + batch_size = args.batch_size or getattr(cfg.training, f"batch_size_{args.flow}") + num_workers = args.num_workers if args.num_workers is not None else dataset_params.num_workers + timeout = args.timeout if args.timeout is not None else dataset_params.dataloader_timeout + + if num_workers == 0: + persistent_workers = False + else: + persistent_workers = args.persistent_workers or bool(dataset_params.persistent_workers) + + dataloader_kwargs = { + "dataset": dataset, + "batch_size": batch_size, + "shuffle": False, + "num_workers": num_workers, + "drop_last": True, + "pin_memory": bool(dataset_params.pin_memory), + "timeout": timeout, + "persistent_workers": persistent_workers, + "collate_fn": _collate_fn_audio, + } + if args.prefetch_factor is not None and num_workers > 0: + dataloader_kwargs["prefetch_factor"] = args.prefetch_factor + + logging.info( + "Starting DataLoader debug: batch_size=%d num_workers=%d timeout=%s persistent=%s", + batch_size, + num_workers, + timeout, + persistent_workers, + ) + + dataloader = DataLoader(**dataloader_kwargs) + + max_batches = args.max_batches + iterator = iter(dataloader) + + processed = 0 + while True: + if max_batches and processed >= max_batches: + break + tick = time.perf_counter() + try: + batch = next(iterator) + except StopIteration: + logging.info("Reached end of DataLoader after %d batches", processed) + break + elapsed = time.perf_counter() - tick + _log_batch_summary(processed, batch, elapsed) + processed += 1 + + return 0 + + +if __name__ == "__main__": + raise SystemExit(main(sys.argv[1:])) diff --git a/MMaDA/eval_ASR_TTS/test.py b/MMaDA/eval_ASR_TTS/test.py new file mode 100644 index 0000000000000000000000000000000000000000..aef3d83056ee0d4ff0502bc18e2fea625d6d236e --- /dev/null +++ b/MMaDA/eval_ASR_TTS/test.py @@ -0,0 +1,266 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "54c0a618-750f-4bf0-8cdb-c2dda158c433", + "metadata": {}, + "outputs": [], + "source": [ + "import argparse\n", + "import json\n", + "import os\n", + "import editdistance" + ] + }, + { + "cell_type": "markdown", + "id": "658bb863-f147-444e-8b14-466e1999d15f", + "metadata": {}, + "source": [ + "# Speech -> Text" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "7e4d5e19-e526-4b33-aa03-0a4cc68abd90", + "metadata": {}, + "outputs": [], + "source": [ + "def calculate_WER(recognized_text_list, groundtruth_text_list):\n", + " word_num = 0.0\n", + " scores = 0.0\n", + " for recognized_text, groundtruth_text in zip(recognized_text_list, groundtruth_text_list):\n", + " if len(recognized_text) > 1000:\n", + " print(recognized_text)\n", + " continue\n", + " recognized_word_list = recognized_text.split()\n", + " groundtruth_word_list = groundtruth_text.split()\n", + " current_word_num = len(groundtruth_word_list)\n", + " word_num += current_word_num\n", + " # Compute Levenstein's distance\n", + " current_score = editdistance.eval(recognized_word_list, groundtruth_word_list)\n", + " scores += current_score\n", + " WER = scores / word_num\n", + " return WER, scores, word_num\n", + "\n", + "\n", + "def evaluate_asr(prediction_list, ground_truth_list):\n", + " wer, scores_wer, word_num_wer = calculate_WER(prediction_list, ground_truth_list)\n", + " print(f'wer: {wer}, scores_wer: {scores_wer}, word_num_wer: {word_num_wer}')" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "05f4a95c", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "WER (demo): 0.4375 | word errors: 7.0 | total words: 16.0\n" + ] + } + ], + "source": [ + "\n", + "gt_0 = \"Hello. We are AIDAS laboratory.\"\n", + "gt_1 = \"Hello. Let's build an omni model diffusion foundation model.\"\n", + "gt_2 = \"Pretty intense.\"\n", + "\n", + "pred_0 = \"hello, we are AIDAS laboratory.\"\n", + "pred_1 = \"hello let's build an omni model diffusion foundation model\"\n", + "pred_2 = \"pretty intense\"\n", + "\n", + "groundtruth_text_list = [gt_0, gt_1, gt_2]\n", + "recognized_text_list = [pred_0, pred_1, pred_2]\n", + "\n", + "wer, errors, words = calculate_WER(recognized_text_list, groundtruth_text_list)\n", + "print(f\"WER (demo): {wer:.4f} | word errors: {errors} | total words: {words}\")" + ] + }, + { + "cell_type": "markdown", + "id": "3635f492-2ae2-4ef4-9321-36d08aa6645e", + "metadata": {}, + "source": [ + "# Text -> Speech (with normalizer)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "1ac74c9a", + "metadata": {}, + "outputs": [], + "source": [ + "# Environment & deps check (safe to run multiple times)\n", + "import sys, os, importlib\n", + "from pathlib import Path\n", + "\n", + "\n", + "# optional: ensure packages (comment out if you manage env separately)\n", + "try:\n", + " import editdistance # used by calculate_WER\n", + "except Exception:\n", + " print(\"Installing editdistance...\")\n", + " %pip -q install editdistance\n", + "\n", + "try:\n", + " import more_itertools # required by english.py normalizer\n", + "except Exception:\n", + " print(\"Installing more-itertools...\")\n", + " %pip -q install more-itertools\n", + "\n", + "# local modules\n", + "from whisper_asr.whisper_asr import load_whisper_model, EN_ASR_WER\n", + "from whisper_asr.normalizers.english import EnglishTextNormalizer # EMOVA-style normalizer\n" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "4ffd26a0", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Device set to use cuda\n", + "Using `chunk_length_s` is very experimental with seq2seq models. The results will not necessarily be entirely accurate and will have caveats. More information: https://github.com/huggingface/transformers/pull/20104. Ignore this warning with pipeline(..., ignore_warning=True). To use Whisper for long-form transcription, use rather the model's `generate` method directly as the model relies on it's own chunking mechanism (cf. Whisper original paper, section 3.8. Long-form Transcription).\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "whisper model loaded!\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 0%| | 0/1 [00:00\]]*[>\]]", "", s) # remove words between brackets + s = re.sub(r"\(([^)]+?)\)", "", s) # remove words between parenthesis + s = self.clean(s).lower() + + if self.split_letters: + s = " ".join(regex.findall(r"\X", s, regex.U)) + + s = re.sub( + r"\s+", " ", s + ) # replace any successive whitespace characters with a space + + return s diff --git a/MMaDA/eval_ASR_TTS/whisper_asr/normalizers/english.json b/MMaDA/eval_ASR_TTS/whisper_asr/normalizers/english.json new file mode 100644 index 0000000000000000000000000000000000000000..74a1c3521d9af4a84a0aa65d95b0859138640799 --- /dev/null +++ b/MMaDA/eval_ASR_TTS/whisper_asr/normalizers/english.json @@ -0,0 +1,1741 @@ +{ + "accessorise": "accessorize", + "accessorised": "accessorized", + "accessorises": "accessorizes", + "accessorising": "accessorizing", + "acclimatisation": "acclimatization", + "acclimatise": "acclimatize", + "acclimatised": "acclimatized", + "acclimatises": "acclimatizes", + "acclimatising": "acclimatizing", + "accoutrements": "accouterments", + "aeon": "eon", + "aeons": "eons", + "aerogramme": "aerogram", + "aerogrammes": "aerograms", + "aeroplane": "airplane", + "aeroplanes": "airplanes", + "aesthete": "esthete", + "aesthetes": "esthetes", + "aesthetic": "esthetic", + "aesthetically": "esthetically", + "aesthetics": "esthetics", + "aetiology": "etiology", + "ageing": "aging", + "aggrandisement": "aggrandizement", + "agonise": "agonize", + "agonised": "agonized", + "agonises": "agonizes", + "agonising": "agonizing", + "agonisingly": "agonizingly", + "almanack": "almanac", + "almanacks": "almanacs", + "aluminium": "aluminum", + "amortisable": "amortizable", + "amortisation": "amortization", + "amortisations": "amortizations", + "amortise": "amortize", + "amortised": "amortized", + "amortises": "amortizes", + "amortising": "amortizing", + "amphitheatre": "amphitheater", + "amphitheatres": "amphitheaters", + "anaemia": "anemia", + "anaemic": "anemic", + "anaesthesia": "anesthesia", + "anaesthetic": "anesthetic", + "anaesthetics": "anesthetics", + "anaesthetise": "anesthetize", + "anaesthetised": "anesthetized", + "anaesthetises": "anesthetizes", + "anaesthetising": "anesthetizing", + "anaesthetist": "anesthetist", + "anaesthetists": "anesthetists", + "anaesthetize": "anesthetize", + "anaesthetized": "anesthetized", + "anaesthetizes": "anesthetizes", + "anaesthetizing": "anesthetizing", + "analogue": "analog", + "analogues": "analogs", + "analyse": "analyze", + "analysed": "analyzed", + "analyses": "analyzes", + "analysing": "analyzing", + "anglicise": "anglicize", + "anglicised": "anglicized", + "anglicises": "anglicizes", + "anglicising": "anglicizing", + "annualised": "annualized", + "antagonise": "antagonize", + "antagonised": "antagonized", + "antagonises": "antagonizes", + "antagonising": "antagonizing", + "apologise": "apologize", + "apologised": "apologized", + "apologises": "apologizes", + "apologising": "apologizing", + "appal": "appall", + "appals": "appalls", + "appetiser": "appetizer", + "appetisers": "appetizers", + "appetising": "appetizing", + "appetisingly": "appetizingly", + "arbour": "arbor", + "arbours": "arbors", + "archeological": "archaeological", + "archaeologically": "archeologically", + "archaeologist": "archeologist", + "archaeologists": "archeologists", + "archaeology": "archeology", + "ardour": "ardor", + "armour": "armor", + "armoured": "armored", + "armourer": "armorer", + "armourers": "armorers", + "armouries": "armories", + "armoury": "armory", + "artefact": "artifact", + "artefacts": "artifacts", + "authorise": "authorize", + "authorised": "authorized", + "authorises": "authorizes", + "authorising": "authorizing", + "axe": "ax", + "backpedalled": "backpedaled", + "backpedalling": "backpedaling", + "bannister": "banister", + "bannisters": "banisters", + "baptise": "baptize", + "baptised": "baptized", + "baptises": "baptizes", + "baptising": "baptizing", + "bastardise": "bastardize", + "bastardised": "bastardized", + "bastardises": "bastardizes", + "bastardising": "bastardizing", + "battleax": "battleaxe", + "baulk": "balk", + "baulked": "balked", + "baulking": "balking", + "baulks": "balks", + "bedevilled": "bedeviled", + "bedevilling": "bedeviling", + "behaviour": "behavior", + "behavioural": "behavioral", + "behaviourism": "behaviorism", + "behaviourist": "behaviorist", + "behaviourists": "behaviorists", + "behaviours": "behaviors", + "behove": "behoove", + "behoved": "behooved", + "behoves": "behooves", + "bejewelled": "bejeweled", + "belabour": "belabor", + "belaboured": "belabored", + "belabouring": "belaboring", + "belabours": "belabors", + "bevelled": "beveled", + "bevvies": "bevies", + "bevvy": "bevy", + "biassed": "biased", + "biassing": "biasing", + "bingeing": "binging", + "bougainvillaea": "bougainvillea", + "bougainvillaeas": "bougainvilleas", + "bowdlerise": "bowdlerize", + "bowdlerised": "bowdlerized", + "bowdlerises": "bowdlerizes", + "bowdlerising": "bowdlerizing", + "breathalyse": "breathalyze", + "breathalysed": "breathalyzed", + "breathalyser": "breathalyzer", + "breathalysers": "breathalyzers", + "breathalyses": "breathalyzes", + "breathalysing": "breathalyzing", + "brutalise": "brutalize", + "brutalised": "brutalized", + "brutalises": "brutalizes", + "brutalising": "brutalizing", + "busses": "buses", + "bussing": "busing", + "caesarean": "cesarean", + "caesareans": "cesareans", + "calibre": "caliber", + "calibres": "calibers", + "calliper": "caliper", + "callipers": "calipers", + "callisthenics": "calisthenics", + "canalise": "canalize", + "canalised": "canalized", + "canalises": "canalizes", + "canalising": "canalizing", + "cancelation": "cancellation", + "cancelations": "cancellations", + "cancelled": "canceled", + "cancelling": "canceling", + "candour": "candor", + "cannibalise": "cannibalize", + "cannibalised": "cannibalized", + "cannibalises": "cannibalizes", + "cannibalising": "cannibalizing", + "canonise": "canonize", + "canonised": "canonized", + "canonises": "canonizes", + "canonising": "canonizing", + "capitalise": "capitalize", + "capitalised": "capitalized", + "capitalises": "capitalizes", + "capitalising": "capitalizing", + "caramelise": "caramelize", + "caramelised": "caramelized", + "caramelises": "caramelizes", + "caramelising": "caramelizing", + "carbonise": "carbonize", + "carbonised": "carbonized", + "carbonises": "carbonizes", + "carbonising": "carbonizing", + "carolled": "caroled", + "carolling": "caroling", + "catalogue": "catalog", + "catalogued": "cataloged", + "catalogues": "catalogs", + "cataloguing": "cataloging", + "catalyse": "catalyze", + "catalysed": "catalyzed", + "catalyses": "catalyzes", + "catalysing": "catalyzing", + "categorise": "categorize", + "categorised": "categorized", + "categorises": "categorizes", + "categorising": "categorizing", + "cauterise": "cauterize", + "cauterised": "cauterized", + "cauterises": "cauterizes", + "cauterising": "cauterizing", + "cavilled": "caviled", + "cavilling": "caviling", + "centigramme": "centigram", + "centigrammes": "centigrams", + "centilitre": "centiliter", + "centilitres": "centiliters", + "centimetre": "centimeter", + "centimetres": "centimeters", + "centralise": "centralize", + "centralised": "centralized", + "centralises": "centralizes", + "centralising": "centralizing", + "centre": "center", + "centred": "centered", + "centrefold": "centerfold", + "centrefolds": "centerfolds", + "centrepiece": "centerpiece", + "centrepieces": "centerpieces", + "centres": "centers", + "channelled": "channeled", + "channelling": "channeling", + "characterise": "characterize", + "characterised": "characterized", + "characterises": "characterizes", + "characterising": "characterizing", + "cheque": "check", + "chequebook": "checkbook", + "chequebooks": "checkbooks", + "chequered": "checkered", + "cheques": "checks", + "chilli": "chili", + "chimaera": "chimera", + "chimaeras": "chimeras", + "chiselled": "chiseled", + "chiselling": "chiseling", + "circularise": "circularize", + "circularised": "circularized", + "circularises": "circularizes", + "circularising": "circularizing", + "civilise": "civilize", + "civilised": "civilized", + "civilises": "civilizes", + "civilising": "civilizing", + "clamour": "clamor", + "clamoured": "clamored", + "clamouring": "clamoring", + "clamours": "clamors", + "clangour": "clangor", + "clarinettist": "clarinetist", + "clarinettists": "clarinetists", + "collectivise": "collectivize", + "collectivised": "collectivized", + "collectivises": "collectivizes", + "collectivising": "collectivizing", + "colonisation": "colonization", + "colonise": "colonize", + "colonised": "colonized", + "coloniser": "colonizer", + "colonisers": "colonizers", + "colonises": "colonizes", + "colonising": "colonizing", + "colour": "color", + "colourant": "colorant", + "colourants": "colorants", + "coloured": "colored", + "coloureds": "coloreds", + "colourful": "colorful", + "colourfully": "colorfully", + "colouring": "coloring", + "colourize": "colorize", + "colourized": "colorized", + "colourizes": "colorizes", + "colourizing": "colorizing", + "colourless": "colorless", + "colours": "colors", + "commercialise": "commercialize", + "commercialised": "commercialized", + "commercialises": "commercializes", + "commercialising": "commercializing", + "compartmentalise": "compartmentalize", + "compartmentalised": "compartmentalized", + "compartmentalises": "compartmentalizes", + "compartmentalising": "compartmentalizing", + "computerise": "computerize", + "computerised": "computerized", + "computerises": "computerizes", + "computerising": "computerizing", + "conceptualise": "conceptualize", + "conceptualised": "conceptualized", + "conceptualises": "conceptualizes", + "conceptualising": "conceptualizing", + "connexion": "connection", + "connexions": "connections", + "contextualise": "contextualize", + "contextualised": "contextualized", + "contextualises": "contextualizes", + "contextualising": "contextualizing", + "cosier": "cozier", + "cosies": "cozies", + "cosiest": "coziest", + "cosily": "cozily", + "cosiness": "coziness", + "cosy": "cozy", + "councillor": "councilor", + "councillors": "councilors", + "counselled": "counseled", + "counselling": "counseling", + "counsellor": "counselor", + "counsellors": "counselors", + "crenelated": "crenellated", + "criminalise": "criminalize", + "criminalised": "criminalized", + "criminalises": "criminalizes", + "criminalising": "criminalizing", + "criticise": "criticize", + "criticised": "criticized", + "criticises": "criticizes", + "criticising": "criticizing", + "crueller": "crueler", + "cruellest": "cruelest", + "crystallisation": "crystallization", + "crystallise": "crystallize", + "crystallised": "crystallized", + "crystallises": "crystallizes", + "crystallising": "crystallizing", + "cudgelled": "cudgeled", + "cudgelling": "cudgeling", + "customise": "customize", + "customised": "customized", + "customises": "customizes", + "customising": "customizing", + "cypher": "cipher", + "cyphers": "ciphers", + "decentralisation": "decentralization", + "decentralise": "decentralize", + "decentralised": "decentralized", + "decentralises": "decentralizes", + "decentralising": "decentralizing", + "decriminalisation": "decriminalization", + "decriminalise": "decriminalize", + "decriminalised": "decriminalized", + "decriminalises": "decriminalizes", + "decriminalising": "decriminalizing", + "defence": "defense", + "defenceless": "defenseless", + "defences": "defenses", + "dehumanisation": "dehumanization", + "dehumanise": "dehumanize", + "dehumanised": "dehumanized", + "dehumanises": "dehumanizes", + "dehumanising": "dehumanizing", + "demeanour": "demeanor", + "demilitarisation": "demilitarization", + "demilitarise": "demilitarize", + "demilitarised": "demilitarized", + "demilitarises": "demilitarizes", + "demilitarising": "demilitarizing", + "demobilisation": "demobilization", + "demobilise": "demobilize", + "demobilised": "demobilized", + "demobilises": "demobilizes", + "demobilising": "demobilizing", + "democratisation": "democratization", + "democratise": "democratize", + "democratised": "democratized", + "democratises": "democratizes", + "democratising": "democratizing", + "demonise": "demonize", + "demonised": "demonized", + "demonises": "demonizes", + "demonising": "demonizing", + "demoralisation": "demoralization", + "demoralise": "demoralize", + "demoralised": "demoralized", + "demoralises": "demoralizes", + "demoralising": "demoralizing", + "denationalisation": "denationalization", + "denationalise": "denationalize", + "denationalised": "denationalized", + "denationalises": "denationalizes", + "denationalising": "denationalizing", + "deodorise": "deodorize", + "deodorised": "deodorized", + "deodorises": "deodorizes", + "deodorising": "deodorizing", + "depersonalise": "depersonalize", + "depersonalised": "depersonalized", + "depersonalises": "depersonalizes", + "depersonalising": "depersonalizing", + "deputise": "deputize", + "deputised": "deputized", + "deputises": "deputizes", + "deputising": "deputizing", + "desensitisation": "desensitization", + "desensitise": "desensitize", + "desensitised": "desensitized", + "desensitises": "desensitizes", + "desensitising": "desensitizing", + "destabilisation": "destabilization", + "destabilise": "destabilize", + "destabilised": "destabilized", + "destabilises": "destabilizes", + "destabilising": "destabilizing", + "dialled": "dialed", + "dialling": "dialing", + "dialogue": "dialog", + "dialogues": "dialogs", + "diarrhoea": "diarrhea", + "digitise": "digitize", + "digitised": "digitized", + "digitises": "digitizes", + "digitising": "digitizing", + "disc": "disk", + "discolour": "discolor", + "discoloured": "discolored", + "discolouring": "discoloring", + "discolours": "discolors", + "discs": "disks", + "disembowelled": "disemboweled", + "disembowelling": "disemboweling", + "disfavour": "disfavor", + "dishevelled": "disheveled", + "dishonour": "dishonor", + "dishonourable": "dishonorable", + "dishonourably": "dishonorably", + "dishonoured": "dishonored", + "dishonouring": "dishonoring", + "dishonours": "dishonors", + "disorganisation": "disorganization", + "disorganised": "disorganized", + "distil": "distill", + "distils": "distills", + "dramatisation": "dramatization", + "dramatisations": "dramatizations", + "dramatise": "dramatize", + "dramatised": "dramatized", + "dramatises": "dramatizes", + "dramatising": "dramatizing", + "draught": "draft", + "draughtboard": "draftboard", + "draughtboards": "draftboards", + "draughtier": "draftier", + "draughtiest": "draftiest", + "draughts": "drafts", + "draughtsman": "draftsman", + "draughtsmanship": "draftsmanship", + "draughtsmen": "draftsmen", + "draughtswoman": "draftswoman", + "draughtswomen": "draftswomen", + "draughty": "drafty", + "drivelled": "driveled", + "drivelling": "driveling", + "duelled": "dueled", + "duelling": "dueling", + "economise": "economize", + "economised": "economized", + "economises": "economizes", + "economising": "economizing", + "edoema": "edema", + "editorialise": "editorialize", + "editorialised": "editorialized", + "editorialises": "editorializes", + "editorialising": "editorializing", + "empathise": "empathize", + "empathised": "empathized", + "empathises": "empathizes", + "empathising": "empathizing", + "emphasise": "emphasize", + "emphasised": "emphasized", + "emphasises": "emphasizes", + "emphasising": "emphasizing", + "enamelled": "enameled", + "enamelling": "enameling", + "enamoured": "enamored", + "encyclopaedia": "encyclopedia", + "encyclopaedias": "encyclopedias", + "encyclopaedic": "encyclopedic", + "endeavour": "endeavor", + "endeavoured": "endeavored", + "endeavouring": "endeavoring", + "endeavours": "endeavors", + "energise": "energize", + "energised": "energized", + "energises": "energizes", + "energising": "energizing", + "enrol": "enroll", + "enrols": "enrolls", + "enthral": "enthrall", + "enthrals": "enthralls", + "epaulette": "epaulet", + "epaulettes": "epaulets", + "epicentre": "epicenter", + "epicentres": "epicenters", + "epilogue": "epilog", + "epilogues": "epilogs", + "epitomise": "epitomize", + "epitomised": "epitomized", + "epitomises": "epitomizes", + "epitomising": "epitomizing", + "equalisation": "equalization", + "equalise": "equalize", + "equalised": "equalized", + "equaliser": "equalizer", + "equalisers": "equalizers", + "equalises": "equalizes", + "equalising": "equalizing", + "eulogise": "eulogize", + "eulogised": "eulogized", + "eulogises": "eulogizes", + "eulogising": "eulogizing", + "evangelise": "evangelize", + "evangelised": "evangelized", + "evangelises": "evangelizes", + "evangelising": "evangelizing", + "exorcise": "exorcize", + "exorcised": "exorcized", + "exorcises": "exorcizes", + "exorcising": "exorcizing", + "extemporisation": "extemporization", + "extemporise": "extemporize", + "extemporised": "extemporized", + "extemporises": "extemporizes", + "extemporising": "extemporizing", + "externalisation": "externalization", + "externalisations": "externalizations", + "externalise": "externalize", + "externalised": "externalized", + "externalises": "externalizes", + "externalising": "externalizing", + "factorise": "factorize", + "factorised": "factorized", + "factorises": "factorizes", + "factorising": "factorizing", + "faecal": "fecal", + "faeces": "feces", + "familiarisation": "familiarization", + "familiarise": "familiarize", + "familiarised": "familiarized", + "familiarises": "familiarizes", + "familiarising": "familiarizing", + "fantasise": "fantasize", + "fantasised": "fantasized", + "fantasises": "fantasizes", + "fantasising": "fantasizing", + "favour": "favor", + "favourable": "favorable", + "favourably": "favorably", + "favoured": "favored", + "favouring": "favoring", + "favourite": "favorite", + "favourites": "favorites", + "favouritism": "favoritism", + "favours": "favors", + "feminise": "feminize", + "feminised": "feminized", + "feminises": "feminizes", + "feminising": "feminizing", + "fertilisation": "fertilization", + "fertilise": "fertilize", + "fertilised": "fertilized", + "fertiliser": "fertilizer", + "fertilisers": "fertilizers", + "fertilises": "fertilizes", + "fertilising": "fertilizing", + "fervour": "fervor", + "fibre": "fiber", + "fibreglass": "fiberglass", + "fibres": "fibers", + "fictionalisation": "fictionalization", + "fictionalisations": "fictionalizations", + "fictionalise": "fictionalize", + "fictionalised": "fictionalized", + "fictionalises": "fictionalizes", + "fictionalising": "fictionalizing", + "fillet": "filet", + "filleted": "fileted", + "filleting": "fileting", + "fillets": "filets", + "finalisation": "finalization", + "finalise": "finalize", + "finalised": "finalized", + "finalises": "finalizes", + "finalising": "finalizing", + "flautist": "flutist", + "flautists": "flutists", + "flavour": "flavor", + "flavoured": "flavored", + "flavouring": "flavoring", + "flavourings": "flavorings", + "flavourless": "flavorless", + "flavours": "flavors", + "flavoursome": "flavorsome", + "flyer / flier": "flier / flyer", + "foetal": "fetal", + "foetid": "fetid", + "foetus": "fetus", + "foetuses": "fetuses", + "formalisation": "formalization", + "formalise": "formalize", + "formalised": "formalized", + "formalises": "formalizes", + "formalising": "formalizing", + "fossilisation": "fossilization", + "fossilise": "fossilize", + "fossilised": "fossilized", + "fossilises": "fossilizes", + "fossilising": "fossilizing", + "fraternisation": "fraternization", + "fraternise": "fraternize", + "fraternised": "fraternized", + "fraternises": "fraternizes", + "fraternising": "fraternizing", + "fulfil": "fulfill", + "fulfilment": "fulfillment", + "fulfils": "fulfills", + "funnelled": "funneled", + "funnelling": "funneling", + "galvanise": "galvanize", + "galvanised": "galvanized", + "galvanises": "galvanizes", + "galvanising": "galvanizing", + "gambolled": "gamboled", + "gambolling": "gamboling", + "gaol": "jail", + "gaolbird": "jailbird", + "gaolbirds": "jailbirds", + "gaolbreak": "jailbreak", + "gaolbreaks": "jailbreaks", + "gaoled": "jailed", + "gaoler": "jailer", + "gaolers": "jailers", + "gaoling": "jailing", + "gaols": "jails", + "gasses": "gases", + "gage": "gauge", + "gaged": "gauged", + "gages": "gauges", + "gaging": "gauging", + "generalisation": "generalization", + "generalisations": "generalizations", + "generalise": "generalize", + "generalised": "generalized", + "generalises": "generalizes", + "generalising": "generalizing", + "ghettoise": "ghettoize", + "ghettoised": "ghettoized", + "ghettoises": "ghettoizes", + "ghettoising": "ghettoizing", + "gipsies": "gypsies", + "glamorise": "glamorize", + "glamorised": "glamorized", + "glamorises": "glamorizes", + "glamorising": "glamorizing", + "glamor": "glamour", + "globalisation": "globalization", + "globalise": "globalize", + "globalised": "globalized", + "globalises": "globalizes", + "globalising": "globalizing", + "glueing": "gluing", + "goitre": "goiter", + "goitres": "goiters", + "gonorrhoea": "gonorrhea", + "gramme": "gram", + "grammes": "grams", + "gravelled": "graveled", + "grey": "gray", + "greyed": "grayed", + "greying": "graying", + "greyish": "grayish", + "greyness": "grayness", + "greys": "grays", + "grovelled": "groveled", + "grovelling": "groveling", + "groyne": "groin", + "groynes": "groins", + "gruelling": "grueling", + "gruellingly": "gruelingly", + "gryphon": "griffin", + "gryphons": "griffins", + "gynaecological": "gynecological", + "gynaecologist": "gynecologist", + "gynaecologists": "gynecologists", + "gynaecology": "gynecology", + "haematological": "hematological", + "haematologist": "hematologist", + "haematologists": "hematologists", + "haematology": "hematology", + "haemoglobin": "hemoglobin", + "haemophilia": "hemophilia", + "haemophiliac": "hemophiliac", + "haemophiliacs": "hemophiliacs", + "haemorrhage": "hemorrhage", + "haemorrhaged": "hemorrhaged", + "haemorrhages": "hemorrhages", + "haemorrhaging": "hemorrhaging", + "haemorrhoids": "hemorrhoids", + "harbour": "harbor", + "harboured": "harbored", + "harbouring": "harboring", + "harbours": "harbors", + "harmonisation": "harmonization", + "harmonise": "harmonize", + "harmonised": "harmonized", + "harmonises": "harmonizes", + "harmonising": "harmonizing", + "homoeopath": "homeopath", + "homoeopathic": "homeopathic", + "homoeopaths": "homeopaths", + "homoeopathy": "homeopathy", + "homogenise": "homogenize", + "homogenised": "homogenized", + "homogenises": "homogenizes", + "homogenising": "homogenizing", + "honour": "honor", + "honourable": "honorable", + "honourably": "honorably", + "honoured": "honored", + "honouring": "honoring", + "honours": "honors", + "hospitalisation": "hospitalization", + "hospitalise": "hospitalize", + "hospitalised": "hospitalized", + "hospitalises": "hospitalizes", + "hospitalising": "hospitalizing", + "humanise": "humanize", + "humanised": "humanized", + "humanises": "humanizes", + "humanising": "humanizing", + "humour": "humor", + "humoured": "humored", + "humouring": "humoring", + "humourless": "humorless", + "humours": "humors", + "hybridise": "hybridize", + "hybridised": "hybridized", + "hybridises": "hybridizes", + "hybridising": "hybridizing", + "hypnotise": "hypnotize", + "hypnotised": "hypnotized", + "hypnotises": "hypnotizes", + "hypnotising": "hypnotizing", + "hypothesise": "hypothesize", + "hypothesised": "hypothesized", + "hypothesises": "hypothesizes", + "hypothesising": "hypothesizing", + "idealisation": "idealization", + "idealise": "idealize", + "idealised": "idealized", + "idealises": "idealizes", + "idealising": "idealizing", + "idolise": "idolize", + "idolised": "idolized", + "idolises": "idolizes", + "idolising": "idolizing", + "immobilisation": "immobilization", + "immobilise": "immobilize", + "immobilised": "immobilized", + "immobiliser": "immobilizer", + "immobilisers": "immobilizers", + "immobilises": "immobilizes", + "immobilising": "immobilizing", + "immortalise": "immortalize", + "immortalised": "immortalized", + "immortalises": "immortalizes", + "immortalising": "immortalizing", + "immunisation": "immunization", + "immunise": "immunize", + "immunised": "immunized", + "immunises": "immunizes", + "immunising": "immunizing", + "impanelled": "impaneled", + "impanelling": "impaneling", + "imperilled": "imperiled", + "imperilling": "imperiling", + "individualise": "individualize", + "individualised": "individualized", + "individualises": "individualizes", + "individualising": "individualizing", + "industrialise": "industrialize", + "industrialised": "industrialized", + "industrialises": "industrializes", + "industrialising": "industrializing", + "inflexion": "inflection", + "inflexions": "inflections", + "initialise": "initialize", + "initialised": "initialized", + "initialises": "initializes", + "initialising": "initializing", + "initialled": "initialed", + "initialling": "initialing", + "instal": "install", + "instalment": "installment", + "instalments": "installments", + "instals": "installs", + "instil": "instill", + "instils": "instills", + "institutionalisation": "institutionalization", + "institutionalise": "institutionalize", + "institutionalised": "institutionalized", + "institutionalises": "institutionalizes", + "institutionalising": "institutionalizing", + "intellectualise": "intellectualize", + "intellectualised": "intellectualized", + "intellectualises": "intellectualizes", + "intellectualising": "intellectualizing", + "internalisation": "internalization", + "internalise": "internalize", + "internalised": "internalized", + "internalises": "internalizes", + "internalising": "internalizing", + "internationalisation": "internationalization", + "internationalise": "internationalize", + "internationalised": "internationalized", + "internationalises": "internationalizes", + "internationalising": "internationalizing", + "ionisation": "ionization", + "ionise": "ionize", + "ionised": "ionized", + "ioniser": "ionizer", + "ionisers": "ionizers", + "ionises": "ionizes", + "ionising": "ionizing", + "italicise": "italicize", + "italicised": "italicized", + "italicises": "italicizes", + "italicising": "italicizing", + "itemise": "itemize", + "itemised": "itemized", + "itemises": "itemizes", + "itemising": "itemizing", + "jeopardise": "jeopardize", + "jeopardised": "jeopardized", + "jeopardises": "jeopardizes", + "jeopardising": "jeopardizing", + "jewelled": "jeweled", + "jeweller": "jeweler", + "jewellers": "jewelers", + "jewellery": "jewelry", + "judgement": "judgment", + "kilogramme": "kilogram", + "kilogrammes": "kilograms", + "kilometre": "kilometer", + "kilometres": "kilometers", + "labelled": "labeled", + "labelling": "labeling", + "labour": "labor", + "laboured": "labored", + "labourer": "laborer", + "labourers": "laborers", + "labouring": "laboring", + "labours": "labors", + "lacklustre": "lackluster", + "legalisation": "legalization", + "legalise": "legalize", + "legalised": "legalized", + "legalises": "legalizes", + "legalising": "legalizing", + "legitimise": "legitimize", + "legitimised": "legitimized", + "legitimises": "legitimizes", + "legitimising": "legitimizing", + "leukaemia": "leukemia", + "levelled": "leveled", + "leveller": "leveler", + "levellers": "levelers", + "levelling": "leveling", + "libelled": "libeled", + "libelling": "libeling", + "libellous": "libelous", + "liberalisation": "liberalization", + "liberalise": "liberalize", + "liberalised": "liberalized", + "liberalises": "liberalizes", + "liberalising": "liberalizing", + "licence": "license", + "licenced": "licensed", + "licences": "licenses", + "licencing": "licensing", + "likeable": "likable", + "lionisation": "lionization", + "lionise": "lionize", + "lionised": "lionized", + "lionises": "lionizes", + "lionising": "lionizing", + "liquidise": "liquidize", + "liquidised": "liquidized", + "liquidiser": "liquidizer", + "liquidisers": "liquidizers", + "liquidises": "liquidizes", + "liquidising": "liquidizing", + "litre": "liter", + "litres": "liters", + "localise": "localize", + "localised": "localized", + "localises": "localizes", + "localising": "localizing", + "louvre": "louver", + "louvred": "louvered", + "louvres": "louvers", + "lustre": "luster", + "magnetise": "magnetize", + "magnetised": "magnetized", + "magnetises": "magnetizes", + "magnetising": "magnetizing", + "manoeuvrability": "maneuverability", + "manoeuvrable": "maneuverable", + "manoeuvre": "maneuver", + "manoeuvred": "maneuvered", + "manoeuvres": "maneuvers", + "manoeuvring": "maneuvering", + "manoeuvrings": "maneuverings", + "marginalisation": "marginalization", + "marginalise": "marginalize", + "marginalised": "marginalized", + "marginalises": "marginalizes", + "marginalising": "marginalizing", + "marshalled": "marshaled", + "marshalling": "marshaling", + "marvelled": "marveled", + "marvelling": "marveling", + "marvellous": "marvelous", + "marvellously": "marvelously", + "materialisation": "materialization", + "materialise": "materialize", + "materialised": "materialized", + "materialises": "materializes", + "materialising": "materializing", + "maximisation": "maximization", + "maximise": "maximize", + "maximised": "maximized", + "maximises": "maximizes", + "maximising": "maximizing", + "meagre": "meager", + "mechanisation": "mechanization", + "mechanise": "mechanize", + "mechanised": "mechanized", + "mechanises": "mechanizes", + "mechanising": "mechanizing", + "mediaeval": "medieval", + "memorialise": "memorialize", + "memorialised": "memorialized", + "memorialises": "memorializes", + "memorialising": "memorializing", + "memorise": "memorize", + "memorised": "memorized", + "memorises": "memorizes", + "memorising": "memorizing", + "mesmerise": "mesmerize", + "mesmerised": "mesmerized", + "mesmerises": "mesmerizes", + "mesmerising": "mesmerizing", + "metabolise": "metabolize", + "metabolised": "metabolized", + "metabolises": "metabolizes", + "metabolising": "metabolizing", + "metre": "meter", + "metres": "meters", + "micrometre": "micrometer", + "micrometres": "micrometers", + "militarise": "militarize", + "militarised": "militarized", + "militarises": "militarizes", + "militarising": "militarizing", + "milligramme": "milligram", + "milligrammes": "milligrams", + "millilitre": "milliliter", + "millilitres": "milliliters", + "millimetre": "millimeter", + "millimetres": "millimeters", + "miniaturisation": "miniaturization", + "miniaturise": "miniaturize", + "miniaturised": "miniaturized", + "miniaturises": "miniaturizes", + "miniaturising": "miniaturizing", + "minibusses": "minibuses", + "minimise": "minimize", + "minimised": "minimized", + "minimises": "minimizes", + "minimising": "minimizing", + "misbehaviour": "misbehavior", + "misdemeanour": "misdemeanor", + "misdemeanours": "misdemeanors", + "misspelt": "misspelled", + "mitre": "miter", + "mitres": "miters", + "mobilisation": "mobilization", + "mobilise": "mobilize", + "mobilised": "mobilized", + "mobilises": "mobilizes", + "mobilising": "mobilizing", + "modelled": "modeled", + "modeller": "modeler", + "modellers": "modelers", + "modelling": "modeling", + "modernise": "modernize", + "modernised": "modernized", + "modernises": "modernizes", + "modernising": "modernizing", + "moisturise": "moisturize", + "moisturised": "moisturized", + "moisturiser": "moisturizer", + "moisturisers": "moisturizers", + "moisturises": "moisturizes", + "moisturising": "moisturizing", + "monologue": "monolog", + "monologues": "monologs", + "monopolisation": "monopolization", + "monopolise": "monopolize", + "monopolised": "monopolized", + "monopolises": "monopolizes", + "monopolising": "monopolizing", + "moralise": "moralize", + "moralised": "moralized", + "moralises": "moralizes", + "moralising": "moralizing", + "motorised": "motorized", + "mould": "mold", + "moulded": "molded", + "moulder": "molder", + "mouldered": "moldered", + "mouldering": "moldering", + "moulders": "molders", + "mouldier": "moldier", + "mouldiest": "moldiest", + "moulding": "molding", + "mouldings": "moldings", + "moulds": "molds", + "mouldy": "moldy", + "moult": "molt", + "moulted": "molted", + "moulting": "molting", + "moults": "molts", + "moustache": "mustache", + "moustached": "mustached", + "moustaches": "mustaches", + "moustachioed": "mustachioed", + "multicoloured": "multicolored", + "nationalisation": "nationalization", + "nationalisations": "nationalizations", + "nationalise": "nationalize", + "nationalised": "nationalized", + "nationalises": "nationalizes", + "nationalising": "nationalizing", + "naturalisation": "naturalization", + "naturalise": "naturalize", + "naturalised": "naturalized", + "naturalises": "naturalizes", + "naturalising": "naturalizing", + "neighbour": "neighbor", + "neighbourhood": "neighborhood", + "neighbourhoods": "neighborhoods", + "neighbouring": "neighboring", + "neighbourliness": "neighborliness", + "neighbourly": "neighborly", + "neighbours": "neighbors", + "neutralisation": "neutralization", + "neutralise": "neutralize", + "neutralised": "neutralized", + "neutralises": "neutralizes", + "neutralising": "neutralizing", + "normalisation": "normalization", + "normalise": "normalize", + "normalised": "normalized", + "normalises": "normalizes", + "normalising": "normalizing", + "odour": "odor", + "odourless": "odorless", + "odours": "odors", + "oesophagus": "esophagus", + "oesophaguses": "esophaguses", + "oestrogen": "estrogen", + "offence": "offense", + "offences": "offenses", + "omelette": "omelet", + "omelettes": "omelets", + "optimise": "optimize", + "optimised": "optimized", + "optimises": "optimizes", + "optimising": "optimizing", + "organisation": "organization", + "organisational": "organizational", + "organisations": "organizations", + "organise": "organize", + "organised": "organized", + "organiser": "organizer", + "organisers": "organizers", + "organises": "organizes", + "organising": "organizing", + "orthopaedic": "orthopedic", + "orthopaedics": "orthopedics", + "ostracise": "ostracize", + "ostracised": "ostracized", + "ostracises": "ostracizes", + "ostracising": "ostracizing", + "outmanoeuvre": "outmaneuver", + "outmanoeuvred": "outmaneuvered", + "outmanoeuvres": "outmaneuvers", + "outmanoeuvring": "outmaneuvering", + "overemphasise": "overemphasize", + "overemphasised": "overemphasized", + "overemphasises": "overemphasizes", + "overemphasising": "overemphasizing", + "oxidisation": "oxidization", + "oxidise": "oxidize", + "oxidised": "oxidized", + "oxidises": "oxidizes", + "oxidising": "oxidizing", + "paederast": "pederast", + "paederasts": "pederasts", + "paediatric": "pediatric", + "paediatrician": "pediatrician", + "paediatricians": "pediatricians", + "paediatrics": "pediatrics", + "paedophile": "pedophile", + "paedophiles": "pedophiles", + "paedophilia": "pedophilia", + "palaeolithic": "paleolithic", + "palaeontologist": "paleontologist", + "palaeontologists": "paleontologists", + "palaeontology": "paleontology", + "panelled": "paneled", + "panelling": "paneling", + "panellist": "panelist", + "panellists": "panelists", + "paralyse": "paralyze", + "paralysed": "paralyzed", + "paralyses": "paralyzes", + "paralysing": "paralyzing", + "parcelled": "parceled", + "parcelling": "parceling", + "parlour": "parlor", + "parlours": "parlors", + "particularise": "particularize", + "particularised": "particularized", + "particularises": "particularizes", + "particularising": "particularizing", + "passivisation": "passivization", + "passivise": "passivize", + "passivised": "passivized", + "passivises": "passivizes", + "passivising": "passivizing", + "pasteurisation": "pasteurization", + "pasteurise": "pasteurize", + "pasteurised": "pasteurized", + "pasteurises": "pasteurizes", + "pasteurising": "pasteurizing", + "patronise": "patronize", + "patronised": "patronized", + "patronises": "patronizes", + "patronising": "patronizing", + "patronisingly": "patronizingly", + "pedalled": "pedaled", + "pedalling": "pedaling", + "pedestrianisation": "pedestrianization", + "pedestrianise": "pedestrianize", + "pedestrianised": "pedestrianized", + "pedestrianises": "pedestrianizes", + "pedestrianising": "pedestrianizing", + "penalise": "penalize", + "penalised": "penalized", + "penalises": "penalizes", + "penalising": "penalizing", + "pencilled": "penciled", + "pencilling": "penciling", + "personalise": "personalize", + "personalised": "personalized", + "personalises": "personalizes", + "personalising": "personalizing", + "pharmacopoeia": "pharmacopeia", + "pharmacopoeias": "pharmacopeias", + "philosophise": "philosophize", + "philosophised": "philosophized", + "philosophises": "philosophizes", + "philosophising": "philosophizing", + "philtre": "filter", + "philtres": "filters", + "phoney": "phony", + "plagiarise": "plagiarize", + "plagiarised": "plagiarized", + "plagiarises": "plagiarizes", + "plagiarising": "plagiarizing", + "plough": "plow", + "ploughed": "plowed", + "ploughing": "plowing", + "ploughman": "plowman", + "ploughmen": "plowmen", + "ploughs": "plows", + "ploughshare": "plowshare", + "ploughshares": "plowshares", + "polarisation": "polarization", + "polarise": "polarize", + "polarised": "polarized", + "polarises": "polarizes", + "polarising": "polarizing", + "politicisation": "politicization", + "politicise": "politicize", + "politicised": "politicized", + "politicises": "politicizes", + "politicising": "politicizing", + "popularisation": "popularization", + "popularise": "popularize", + "popularised": "popularized", + "popularises": "popularizes", + "popularising": "popularizing", + "pouffe": "pouf", + "pouffes": "poufs", + "practise": "practice", + "practised": "practiced", + "practises": "practices", + "practising": "practicing", + "praesidium": "presidium", + "praesidiums": "presidiums", + "pressurisation": "pressurization", + "pressurise": "pressurize", + "pressurised": "pressurized", + "pressurises": "pressurizes", + "pressurising": "pressurizing", + "pretence": "pretense", + "pretences": "pretenses", + "primaeval": "primeval", + "prioritisation": "prioritization", + "prioritise": "prioritize", + "prioritised": "prioritized", + "prioritises": "prioritizes", + "prioritising": "prioritizing", + "privatisation": "privatization", + "privatisations": "privatizations", + "privatise": "privatize", + "privatised": "privatized", + "privatises": "privatizes", + "privatising": "privatizing", + "professionalisation": "professionalization", + "professionalise": "professionalize", + "professionalised": "professionalized", + "professionalises": "professionalizes", + "professionalising": "professionalizing", + "programme": "program", + "programmes": "programs", + "prologue": "prolog", + "prologues": "prologs", + "propagandise": "propagandize", + "propagandised": "propagandized", + "propagandises": "propagandizes", + "propagandising": "propagandizing", + "proselytise": "proselytize", + "proselytised": "proselytized", + "proselytiser": "proselytizer", + "proselytisers": "proselytizers", + "proselytises": "proselytizes", + "proselytising": "proselytizing", + "psychoanalyse": "psychoanalyze", + "psychoanalysed": "psychoanalyzed", + "psychoanalyses": "psychoanalyzes", + "psychoanalysing": "psychoanalyzing", + "publicise": "publicize", + "publicised": "publicized", + "publicises": "publicizes", + "publicising": "publicizing", + "pulverisation": "pulverization", + "pulverise": "pulverize", + "pulverised": "pulverized", + "pulverises": "pulverizes", + "pulverising": "pulverizing", + "pummelled": "pummel", + "pummelling": "pummeled", + "pyjama": "pajama", + "pyjamas": "pajamas", + "pzazz": "pizzazz", + "quarrelled": "quarreled", + "quarrelling": "quarreling", + "radicalise": "radicalize", + "radicalised": "radicalized", + "radicalises": "radicalizes", + "radicalising": "radicalizing", + "rancour": "rancor", + "randomise": "randomize", + "randomised": "randomized", + "randomises": "randomizes", + "randomising": "randomizing", + "rationalisation": "rationalization", + "rationalisations": "rationalizations", + "rationalise": "rationalize", + "rationalised": "rationalized", + "rationalises": "rationalizes", + "rationalising": "rationalizing", + "ravelled": "raveled", + "ravelling": "raveling", + "realisable": "realizable", + "realisation": "realization", + "realisations": "realizations", + "realise": "realize", + "realised": "realized", + "realises": "realizes", + "realising": "realizing", + "recognisable": "recognizable", + "recognisably": "recognizably", + "recognisance": "recognizance", + "recognise": "recognize", + "recognised": "recognized", + "recognises": "recognizes", + "recognising": "recognizing", + "reconnoitre": "reconnoiter", + "reconnoitred": "reconnoitered", + "reconnoitres": "reconnoiters", + "reconnoitring": "reconnoitering", + "refuelled": "refueled", + "refuelling": "refueling", + "regularisation": "regularization", + "regularise": "regularize", + "regularised": "regularized", + "regularises": "regularizes", + "regularising": "regularizing", + "remodelled": "remodeled", + "remodelling": "remodeling", + "remould": "remold", + "remoulded": "remolded", + "remoulding": "remolding", + "remoulds": "remolds", + "reorganisation": "reorganization", + "reorganisations": "reorganizations", + "reorganise": "reorganize", + "reorganised": "reorganized", + "reorganises": "reorganizes", + "reorganising": "reorganizing", + "revelled": "reveled", + "reveller": "reveler", + "revellers": "revelers", + "revelling": "reveling", + "revitalise": "revitalize", + "revitalised": "revitalized", + "revitalises": "revitalizes", + "revitalising": "revitalizing", + "revolutionise": "revolutionize", + "revolutionised": "revolutionized", + "revolutionises": "revolutionizes", + "revolutionising": "revolutionizing", + "rhapsodise": "rhapsodize", + "rhapsodised": "rhapsodized", + "rhapsodises": "rhapsodizes", + "rhapsodising": "rhapsodizing", + "rigour": "rigor", + "rigours": "rigors", + "ritualised": "ritualized", + "rivalled": "rivaled", + "rivalling": "rivaling", + "romanticise": "romanticize", + "romanticised": "romanticized", + "romanticises": "romanticizes", + "romanticising": "romanticizing", + "rumour": "rumor", + "rumoured": "rumored", + "rumours": "rumors", + "sabre": "saber", + "sabres": "sabers", + "saltpetre": "saltpeter", + "sanitise": "sanitize", + "sanitised": "sanitized", + "sanitises": "sanitizes", + "sanitising": "sanitizing", + "satirise": "satirize", + "satirised": "satirized", + "satirises": "satirizes", + "satirising": "satirizing", + "saviour": "savior", + "saviours": "saviors", + "savour": "savor", + "savoured": "savored", + "savouries": "savories", + "savouring": "savoring", + "savours": "savors", + "savoury": "savory", + "scandalise": "scandalize", + "scandalised": "scandalized", + "scandalises": "scandalizes", + "scandalising": "scandalizing", + "sceptic": "skeptic", + "sceptical": "skeptical", + "sceptically": "skeptically", + "scepticism": "skepticism", + "sceptics": "skeptics", + "sceptre": "scepter", + "sceptres": "scepters", + "scrutinise": "scrutinize", + "scrutinised": "scrutinized", + "scrutinises": "scrutinizes", + "scrutinising": "scrutinizing", + "secularisation": "secularization", + "secularise": "secularize", + "secularised": "secularized", + "secularises": "secularizes", + "secularising": "secularizing", + "sensationalise": "sensationalize", + "sensationalised": "sensationalized", + "sensationalises": "sensationalizes", + "sensationalising": "sensationalizing", + "sensitise": "sensitize", + "sensitised": "sensitized", + "sensitises": "sensitizes", + "sensitising": "sensitizing", + "sentimentalise": "sentimentalize", + "sentimentalised": "sentimentalized", + "sentimentalises": "sentimentalizes", + "sentimentalising": "sentimentalizing", + "sepulchre": "sepulcher", + "sepulchres": "sepulchers", + "serialisation": "serialization", + "serialisations": "serializations", + "serialise": "serialize", + "serialised": "serialized", + "serialises": "serializes", + "serialising": "serializing", + "sermonise": "sermonize", + "sermonised": "sermonized", + "sermonises": "sermonizes", + "sermonising": "sermonizing", + "sheikh": "sheik", + "shovelled": "shoveled", + "shovelling": "shoveling", + "shrivelled": "shriveled", + "shrivelling": "shriveling", + "signalise": "signalize", + "signalised": "signalized", + "signalises": "signalizes", + "signalising": "signalizing", + "signalled": "signaled", + "signalling": "signaling", + "smoulder": "smolder", + "smouldered": "smoldered", + "smouldering": "smoldering", + "smoulders": "smolders", + "snivelled": "sniveled", + "snivelling": "sniveling", + "snorkelled": "snorkeled", + "snorkelling": "snorkeling", + "snowplough": "snowplow", + "snowploughs": "snowplow", + "socialisation": "socialization", + "socialise": "socialize", + "socialised": "socialized", + "socialises": "socializes", + "socialising": "socializing", + "sodomise": "sodomize", + "sodomised": "sodomized", + "sodomises": "sodomizes", + "sodomising": "sodomizing", + "solemnise": "solemnize", + "solemnised": "solemnized", + "solemnises": "solemnizes", + "solemnising": "solemnizing", + "sombre": "somber", + "specialisation": "specialization", + "specialisations": "specializations", + "specialise": "specialize", + "specialised": "specialized", + "specialises": "specializes", + "specialising": "specializing", + "spectre": "specter", + "spectres": "specters", + "spiralled": "spiraled", + "spiralling": "spiraling", + "splendour": "splendor", + "splendours": "splendors", + "squirrelled": "squirreled", + "squirrelling": "squirreling", + "stabilisation": "stabilization", + "stabilise": "stabilize", + "stabilised": "stabilized", + "stabiliser": "stabilizer", + "stabilisers": "stabilizers", + "stabilises": "stabilizes", + "stabilising": "stabilizing", + "standardisation": "standardization", + "standardise": "standardize", + "standardised": "standardized", + "standardises": "standardizes", + "standardising": "standardizing", + "stencilled": "stenciled", + "stencilling": "stenciling", + "sterilisation": "sterilization", + "sterilisations": "sterilizations", + "sterilise": "sterilize", + "sterilised": "sterilized", + "steriliser": "sterilizer", + "sterilisers": "sterilizers", + "sterilises": "sterilizes", + "sterilising": "sterilizing", + "stigmatisation": "stigmatization", + "stigmatise": "stigmatize", + "stigmatised": "stigmatized", + "stigmatises": "stigmatizes", + "stigmatising": "stigmatizing", + "storey": "story", + "storeys": "stories", + "subsidisation": "subsidization", + "subsidise": "subsidize", + "subsidised": "subsidized", + "subsidiser": "subsidizer", + "subsidisers": "subsidizers", + "subsidises": "subsidizes", + "subsidising": "subsidizing", + "succour": "succor", + "succoured": "succored", + "succouring": "succoring", + "succours": "succors", + "sulphate": "sulfate", + "sulphates": "sulfates", + "sulphide": "sulfide", + "sulphides": "sulfides", + "sulphur": "sulfur", + "sulphurous": "sulfurous", + "summarise": "summarize", + "summarised": "summarized", + "summarises": "summarizes", + "summarising": "summarizing", + "swivelled": "swiveled", + "swivelling": "swiveling", + "symbolise": "symbolize", + "symbolised": "symbolized", + "symbolises": "symbolizes", + "symbolising": "symbolizing", + "sympathise": "sympathize", + "sympathised": "sympathized", + "sympathiser": "sympathizer", + "sympathisers": "sympathizers", + "sympathises": "sympathizes", + "sympathising": "sympathizing", + "synchronisation": "synchronization", + "synchronise": "synchronize", + "synchronised": "synchronized", + "synchronises": "synchronizes", + "synchronising": "synchronizing", + "synthesise": "synthesize", + "synthesised": "synthesized", + "synthesiser": "synthesizer", + "synthesisers": "synthesizers", + "synthesises": "synthesizes", + "synthesising": "synthesizing", + "syphon": "siphon", + "syphoned": "siphoned", + "syphoning": "siphoning", + "syphons": "siphons", + "systematisation": "systematization", + "systematise": "systematize", + "systematised": "systematized", + "systematises": "systematizes", + "systematising": "systematizing", + "tantalise": "tantalize", + "tantalised": "tantalized", + "tantalises": "tantalizes", + "tantalising": "tantalizing", + "tantalisingly": "tantalizingly", + "tasselled": "tasseled", + "technicolour": "technicolor", + "temporise": "temporize", + "temporised": "temporized", + "temporises": "temporizes", + "temporising": "temporizing", + "tenderise": "tenderize", + "tenderised": "tenderized", + "tenderises": "tenderizes", + "tenderising": "tenderizing", + "terrorise": "terrorize", + "terrorised": "terrorized", + "terrorises": "terrorizes", + "terrorising": "terrorizing", + "theatre": "theater", + "theatregoer": "theatergoer", + "theatregoers": "theatergoers", + "theatres": "theaters", + "theorise": "theorize", + "theorised": "theorized", + "theorises": "theorizes", + "theorising": "theorizing", + "tonne": "ton", + "tonnes": "tons", + "towelled": "toweled", + "towelling": "toweling", + "toxaemia": "toxemia", + "tranquillise": "tranquilize", + "tranquillised": "tranquilized", + "tranquilliser": "tranquilizer", + "tranquillisers": "tranquilizers", + "tranquillises": "tranquilizes", + "tranquillising": "tranquilizing", + "tranquillity": "tranquility", + "tranquillize": "tranquilize", + "tranquillized": "tranquilized", + "tranquillizer": "tranquilizer", + "tranquillizers": "tranquilizers", + "tranquillizes": "tranquilizes", + "tranquillizing": "tranquilizing", + "tranquilly": "tranquility", + "transistorised": "transistorized", + "traumatise": "traumatize", + "traumatised": "traumatized", + "traumatises": "traumatizes", + "traumatising": "traumatizing", + "travelled": "traveled", + "traveller": "traveler", + "travellers": "travelers", + "travelling": "traveling", + "travelog": "travelogue", + "travelogs": "travelogues", + "trialled": "trialed", + "trialling": "trialing", + "tricolour": "tricolor", + "tricolours": "tricolors", + "trivialise": "trivialize", + "trivialised": "trivialized", + "trivialises": "trivializes", + "trivialising": "trivializing", + "tumour": "tumor", + "tumours": "tumors", + "tunnelled": "tunneled", + "tunnelling": "tunneling", + "tyrannise": "tyrannize", + "tyrannised": "tyrannized", + "tyrannises": "tyrannizes", + "tyrannising": "tyrannizing", + "tyre": "tire", + "tyres": "tires", + "unauthorised": "unauthorized", + "uncivilised": "uncivilized", + "underutilised": "underutilized", + "unequalled": "unequaled", + "unfavourable": "unfavorable", + "unfavourably": "unfavorably", + "unionisation": "unionization", + "unionise": "unionize", + "unionised": "unionized", + "unionises": "unionizes", + "unionising": "unionizing", + "unorganised": "unorganized", + "unravelled": "unraveled", + "unravelling": "unraveling", + "unrecognisable": "unrecognizable", + "unrecognised": "unrecognized", + "unrivalled": "unrivaled", + "unsavoury": "unsavory", + "untrammelled": "untrammeled", + "urbanisation": "urbanization", + "urbanise": "urbanize", + "urbanised": "urbanized", + "urbanises": "urbanizes", + "urbanising": "urbanizing", + "utilisable": "utilizable", + "utilisation": "utilization", + "utilise": "utilize", + "utilised": "utilized", + "utilises": "utilizes", + "utilising": "utilizing", + "valour": "valor", + "vandalise": "vandalize", + "vandalised": "vandalized", + "vandalises": "vandalizes", + "vandalising": "vandalizing", + "vaporisation": "vaporization", + "vaporise": "vaporize", + "vaporised": "vaporized", + "vaporises": "vaporizes", + "vaporising": "vaporizing", + "vapour": "vapor", + "vapours": "vapors", + "verbalise": "verbalize", + "verbalised": "verbalized", + "verbalises": "verbalizes", + "verbalising": "verbalizing", + "victimisation": "victimization", + "victimise": "victimize", + "victimised": "victimized", + "victimises": "victimizes", + "victimising": "victimizing", + "videodisc": "videodisk", + "videodiscs": "videodisks", + "vigour": "vigor", + "visualisation": "visualization", + "visualisations": "visualizations", + "visualise": "visualize", + "visualised": "visualized", + "visualises": "visualizes", + "visualising": "visualizing", + "vocalisation": "vocalization", + "vocalisations": "vocalizations", + "vocalise": "vocalize", + "vocalised": "vocalized", + "vocalises": "vocalizes", + "vocalising": "vocalizing", + "vulcanised": "vulcanized", + "vulgarisation": "vulgarization", + "vulgarise": "vulgarize", + "vulgarised": "vulgarized", + "vulgarises": "vulgarizes", + "vulgarising": "vulgarizing", + "waggon": "wagon", + "waggons": "wagons", + "watercolour": "watercolor", + "watercolours": "watercolors", + "weaselled": "weaseled", + "weaselling": "weaseling", + "westernisation": "westernization", + "westernise": "westernize", + "westernised": "westernized", + "westernises": "westernizes", + "westernising": "westernizing", + "womanise": "womanize", + "womanised": "womanized", + "womaniser": "womanizer", + "womanisers": "womanizers", + "womanises": "womanizes", + "womanising": "womanizing", + "woollen": "woolen", + "woollens": "woolens", + "woollies": "woolies", + "woolly": "wooly", + "worshipped": "worshiped", + "worshipping": "worshiping", + "worshipper": "worshiper", + "yodelled": "yodeled", + "yodelling": "yodeling", + "yoghourt": "yogurt", + "yoghourts": "yogurts", + "yoghurt": "yogurt", + "yoghurts": "yogurts", + "mhm": "hmm", + "mmm": "hmm" +} \ No newline at end of file diff --git a/MMaDA/eval_ASR_TTS/whisper_asr/normalizers/english.py b/MMaDA/eval_ASR_TTS/whisper_asr/normalizers/english.py new file mode 100644 index 0000000000000000000000000000000000000000..4932042bc5b7e9c3fed75a03af66948e4225a2b0 --- /dev/null +++ b/MMaDA/eval_ASR_TTS/whisper_asr/normalizers/english.py @@ -0,0 +1,550 @@ +import json +import os +import re +from fractions import Fraction +from typing import Iterator, List, Match, Optional, Union + +from more_itertools import windowed + +from .basic import remove_symbols_and_diacritics + + +class EnglishNumberNormalizer: + """ + Convert any spelled-out numbers into arabic numbers, while handling: + + - remove any commas + - keep the suffixes such as: `1960s`, `274th`, `32nd`, etc. + - spell out currency symbols after the number. e.g. `$20 million` -> `20000000 dollars` + - spell out `one` and `ones` + - interpret successive single-digit numbers as nominal: `one oh one` -> `101` + """ + + def __init__(self): + super().__init__() + + self.zeros = {"o", "oh", "zero"} + self.ones = { + name: i + for i, name in enumerate( + [ + "one", + "two", + "three", + "four", + "five", + "six", + "seven", + "eight", + "nine", + "ten", + "eleven", + "twelve", + "thirteen", + "fourteen", + "fifteen", + "sixteen", + "seventeen", + "eighteen", + "nineteen", + ], + start=1, + ) + } + self.ones_plural = { + "sixes" if name == "six" else name + "s": (value, "s") + for name, value in self.ones.items() + } + self.ones_ordinal = { + "zeroth": (0, "th"), + "first": (1, "st"), + "second": (2, "nd"), + "third": (3, "rd"), + "fifth": (5, "th"), + "twelfth": (12, "th"), + **{ + name + ("h" if name.endswith("t") else "th"): (value, "th") + for name, value in self.ones.items() + if value > 3 and value != 5 and value != 12 + }, + } + self.ones_suffixed = {**self.ones_plural, **self.ones_ordinal} + + self.tens = { + "twenty": 20, + "thirty": 30, + "forty": 40, + "fifty": 50, + "sixty": 60, + "seventy": 70, + "eighty": 80, + "ninety": 90, + } + self.tens_plural = { + name.replace("y", "ies"): (value, "s") for name, value in self.tens.items() + } + self.tens_ordinal = { + name.replace("y", "ieth"): (value, "th") + for name, value in self.tens.items() + } + self.tens_suffixed = {**self.tens_plural, **self.tens_ordinal} + + self.multipliers = { + "hundred": 100, + "thousand": 1_000, + "million": 1_000_000, + "billion": 1_000_000_000, + "trillion": 1_000_000_000_000, + "quadrillion": 1_000_000_000_000_000, + "quintillion": 1_000_000_000_000_000_000, + "sextillion": 1_000_000_000_000_000_000_000, + "septillion": 1_000_000_000_000_000_000_000_000, + "octillion": 1_000_000_000_000_000_000_000_000_000, + "nonillion": 1_000_000_000_000_000_000_000_000_000_000, + "decillion": 1_000_000_000_000_000_000_000_000_000_000_000, + } + self.multipliers_plural = { + name + "s": (value, "s") for name, value in self.multipliers.items() + } + self.multipliers_ordinal = { + name + "th": (value, "th") for name, value in self.multipliers.items() + } + self.multipliers_suffixed = { + **self.multipliers_plural, + **self.multipliers_ordinal, + } + self.decimals = {*self.ones, *self.tens, *self.zeros} + + self.preceding_prefixers = { + "minus": "-", + "negative": "-", + "plus": "+", + "positive": "+", + } + self.following_prefixers = { + "pound": "Ā£", + "pounds": "Ā£", + "euro": "€", + "euros": "€", + "dollar": "$", + "dollars": "$", + "cent": "Ā¢", + "cents": "Ā¢", + } + self.prefixes = set( + list(self.preceding_prefixers.values()) + + list(self.following_prefixers.values()) + ) + self.suffixers = { + "per": {"cent": "%"}, + "percent": "%", + } + self.specials = {"and", "double", "triple", "point"} + + self.words = set( + [ + key + for mapping in [ + self.zeros, + self.ones, + self.ones_suffixed, + self.tens, + self.tens_suffixed, + self.multipliers, + self.multipliers_suffixed, + self.preceding_prefixers, + self.following_prefixers, + self.suffixers, + self.specials, + ] + for key in mapping + ] + ) + self.literal_words = {"one", "ones"} + + def process_words(self, words: List[str]) -> Iterator[str]: + prefix: Optional[str] = None + value: Optional[Union[str, int]] = None + skip = False + + def to_fraction(s: str): + try: + return Fraction(s) + except ValueError: + return None + + def output(result: Union[str, int]): + nonlocal prefix, value + result = str(result) + if prefix is not None: + result = prefix + result + value = None + prefix = None + return result + + if len(words) == 0: + return + + for prev, current, next in windowed([None] + words + [None], 3): + if skip: + skip = False + continue + + next_is_numeric = next is not None and re.match(r"^\d+(\.\d+)?$", next) + has_prefix = current[0] in self.prefixes + current_without_prefix = current[1:] if has_prefix else current + if re.match(r"^\d+(\.\d+)?$", current_without_prefix): + # arabic numbers (potentially with signs and fractions) + f = to_fraction(current_without_prefix) + assert f is not None + if value is not None: + if isinstance(value, str) and value.endswith("."): + # concatenate decimals / ip address components + value = str(value) + str(current) + continue + else: + yield output(value) + + prefix = current[0] if has_prefix else prefix + if f.denominator == 1: + value = f.numerator # store integers as int + else: + value = current_without_prefix + elif current not in self.words: + # non-numeric words + if value is not None: + yield output(value) + yield output(current) + elif current in self.zeros: + value = str(value or "") + "0" + elif current in self.ones: + ones = self.ones[current] + + if value is None: + value = ones + elif isinstance(value, str) or prev in self.ones: + if ( + prev in self.tens and ones < 10 + ): # replace the last zero with the digit + assert value[-1] == "0" + value = value[:-1] + str(ones) + else: + value = str(value) + str(ones) + elif ones < 10: + if value % 10 == 0: + value += ones + else: + value = str(value) + str(ones) + else: # eleven to nineteen + if value % 100 == 0: + value += ones + else: + value = str(value) + str(ones) + elif current in self.ones_suffixed: + # ordinal or cardinal; yield the number right away + ones, suffix = self.ones_suffixed[current] + if value is None: + yield output(str(ones) + suffix) + elif isinstance(value, str) or prev in self.ones: + if prev in self.tens and ones < 10: + assert value[-1] == "0" + yield output(value[:-1] + str(ones) + suffix) + else: + yield output(str(value) + str(ones) + suffix) + elif ones < 10: + if value % 10 == 0: + yield output(str(value + ones) + suffix) + else: + yield output(str(value) + str(ones) + suffix) + else: # eleven to nineteen + if value % 100 == 0: + yield output(str(value + ones) + suffix) + else: + yield output(str(value) + str(ones) + suffix) + value = None + elif current in self.tens: + tens = self.tens[current] + if value is None: + value = tens + elif isinstance(value, str): + value = str(value) + str(tens) + else: + if value % 100 == 0: + value += tens + else: + value = str(value) + str(tens) + elif current in self.tens_suffixed: + # ordinal or cardinal; yield the number right away + tens, suffix = self.tens_suffixed[current] + if value is None: + yield output(str(tens) + suffix) + elif isinstance(value, str): + yield output(str(value) + str(tens) + suffix) + else: + if value % 100 == 0: + yield output(str(value + tens) + suffix) + else: + yield output(str(value) + str(tens) + suffix) + elif current in self.multipliers: + multiplier = self.multipliers[current] + if value is None: + value = multiplier + elif isinstance(value, str) or value == 0: + f = to_fraction(value) + p = f * multiplier if f is not None else None + if f is not None and p.denominator == 1: + value = p.numerator + else: + yield output(value) + value = multiplier + else: + before = value // 1000 * 1000 + residual = value % 1000 + value = before + residual * multiplier + elif current in self.multipliers_suffixed: + multiplier, suffix = self.multipliers_suffixed[current] + if value is None: + yield output(str(multiplier) + suffix) + elif isinstance(value, str): + f = to_fraction(value) + p = f * multiplier if f is not None else None + if f is not None and p.denominator == 1: + yield output(str(p.numerator) + suffix) + else: + yield output(value) + yield output(str(multiplier) + suffix) + else: # int + before = value // 1000 * 1000 + residual = value % 1000 + value = before + residual * multiplier + yield output(str(value) + suffix) + value = None + elif current in self.preceding_prefixers: + # apply prefix (positive, minus, etc.) if it precedes a number + if value is not None: + yield output(value) + + if next in self.words or next_is_numeric: + prefix = self.preceding_prefixers[current] + else: + yield output(current) + elif current in self.following_prefixers: + # apply prefix (dollars, cents, etc.) only after a number + if value is not None: + prefix = self.following_prefixers[current] + yield output(value) + else: + yield output(current) + elif current in self.suffixers: + # apply suffix symbols (percent -> '%') + if value is not None: + suffix = self.suffixers[current] + if isinstance(suffix, dict): + if next in suffix: + yield output(str(value) + suffix[next]) + skip = True + else: + yield output(value) + yield output(current) + else: + yield output(str(value) + suffix) + else: + yield output(current) + elif current in self.specials: + if next not in self.words and not next_is_numeric: + # apply special handling only if the next word can be numeric + if value is not None: + yield output(value) + yield output(current) + elif current == "and": + # ignore "and" after hundreds, thousands, etc. + if prev not in self.multipliers: + if value is not None: + yield output(value) + yield output(current) + elif current == "double" or current == "triple": + if next in self.ones or next in self.zeros: + repeats = 2 if current == "double" else 3 + ones = self.ones.get(next, 0) + value = str(value or "") + str(ones) * repeats + skip = True + else: + if value is not None: + yield output(value) + yield output(current) + elif current == "point": + if next in self.decimals or next_is_numeric: + value = str(value or "") + "." + else: + # should all have been covered at this point + raise ValueError(f"Unexpected token: {current}") + else: + # all should have been covered at this point + raise ValueError(f"Unexpected token: {current}") + + if value is not None: + yield output(value) + + def preprocess(self, s: str): + # replace " and a half" with " point five" + results = [] + + segments = re.split(r"\band\s+a\s+half\b", s) + for i, segment in enumerate(segments): + if len(segment.strip()) == 0: + continue + if i == len(segments) - 1: + results.append(segment) + else: + results.append(segment) + last_word = segment.rsplit(maxsplit=2)[-1] + if last_word in self.decimals or last_word in self.multipliers: + results.append("point five") + else: + results.append("and a half") + + s = " ".join(results) + + # put a space at number/letter boundary + s = re.sub(r"([a-z])([0-9])", r"\1 \2", s) + s = re.sub(r"([0-9])([a-z])", r"\1 \2", s) + + # but remove spaces which could be a suffix + s = re.sub(r"([0-9])\s+(st|nd|rd|th|s)\b", r"\1\2", s) + + return s + + def postprocess(self, s: str): + def combine_cents(m: Match): + try: + currency = m.group(1) + integer = m.group(2) + cents = int(m.group(3)) + return f"{currency}{integer}.{cents:02d}" + except ValueError: + return m.string + + def extract_cents(m: Match): + try: + return f"Ā¢{int(m.group(1))}" + except ValueError: + return m.string + + # apply currency postprocessing; "$2 and Ā¢7" -> "$2.07" + s = re.sub(r"([€£$])([0-9]+) (?:and )?Ā¢([0-9]{1,2})\b", combine_cents, s) + s = re.sub(r"[€£$]0.([0-9]{1,2})\b", extract_cents, s) + + # write "one(s)" instead of "1(s)", just for the readability + s = re.sub(r"\b1(s?)\b", r"one\1", s) + + return s + + def __call__(self, s: str): + s = self.preprocess(s) + s = " ".join(word for word in self.process_words(s.split()) if word is not None) + s = self.postprocess(s) + + return s + + +class EnglishSpellingNormalizer: + """ + Applies British-American spelling mappings as listed in [1]. + + [1] https://www.tysto.com/uk-us-spelling-list.html + """ + + def __init__(self): + mapping_path = os.path.join(os.path.dirname(__file__), "english.json") + self.mapping = json.load(open(mapping_path)) + + def __call__(self, s: str): + return " ".join(self.mapping.get(word, word) for word in s.split()) + + +class EnglishTextNormalizer: + def __init__(self): + self.ignore_patterns = r"\b(hmm|mm|mhm|mmm|uh|um)\b" + self.replacers = { + # common contractions + r"\bwon't\b": "will not", + r"\bcan't\b": "can not", + r"\blet's\b": "let us", + r"\bain't\b": "aint", + r"\by'all\b": "you all", + r"\bwanna\b": "want to", + r"\bgotta\b": "got to", + r"\bgonna\b": "going to", + r"\bi'ma\b": "i am going to", + r"\bimma\b": "i am going to", + r"\bwoulda\b": "would have", + r"\bcoulda\b": "could have", + r"\bshoulda\b": "should have", + r"\bma'am\b": "madam", + # contractions in titles/prefixes + r"\bmr\b": "mister ", + r"\bmrs\b": "missus ", + r"\bst\b": "saint ", + r"\bdr\b": "doctor ", + r"\bprof\b": "professor ", + r"\bcapt\b": "captain ", + r"\bgov\b": "governor ", + r"\bald\b": "alderman ", + r"\bgen\b": "general ", + r"\bsen\b": "senator ", + r"\brep\b": "representative ", + r"\bpres\b": "president ", + r"\brev\b": "reverend ", + r"\bhon\b": "honorable ", + r"\basst\b": "assistant ", + r"\bassoc\b": "associate ", + r"\blt\b": "lieutenant ", + r"\bcol\b": "colonel ", + r"\bjr\b": "junior ", + r"\bsr\b": "senior ", + r"\besq\b": "esquire ", + # prefect tenses, ideally it should be any past participles, but it's harder.. + r"'d been\b": " had been", + r"'s been\b": " has been", + r"'d gone\b": " had gone", + r"'s gone\b": " has gone", + r"'d done\b": " had done", # "'s done" is ambiguous + r"'s got\b": " has got", + # general contractions + r"n't\b": " not", + r"'re\b": " are", + r"'s\b": " is", + r"'d\b": " would", + r"'ll\b": " will", + r"'t\b": " not", + r"'ve\b": " have", + r"'m\b": " am", + } + self.standardize_numbers = EnglishNumberNormalizer() + self.standardize_spellings = EnglishSpellingNormalizer() + + def __call__(self, s: str): + s = s.lower() + + s = re.sub(r"[<\[][^>\]]*[>\]]", "", s) # remove words between brackets + s = re.sub(r"\(([^)]+?)\)", "", s) # remove words between parenthesis + s = re.sub(self.ignore_patterns, "", s) + s = re.sub(r"\s+'", "'", s) # when there's a space before an apostrophe + + for pattern, replacement in self.replacers.items(): + s = re.sub(pattern, replacement, s) + + s = re.sub(r"(\d),(\d)", r"\1\2", s) # remove commas between digits + s = re.sub(r"\.([^0-9]|$)", r" \1", s) # remove periods not followed by numbers + s = remove_symbols_and_diacritics(s, keep=".%$¢€£") # keep numeric symbols + + s = self.standardize_numbers(s) + s = self.standardize_spellings(s) + + # now remove prefix/suffix symbols that are not preceded/followed by numbers + s = re.sub(r"[.$¢€£]([^0-9])", r" \1", s) + s = re.sub(r"([^0-9])%", r"\1 ", s) + + s = re.sub(r"\s+", " ", s) # replace any successive whitespaces with a space + + return s diff --git a/MMaDA/eval_ASR_TTS/whisper_asr/whisper_asr.py b/MMaDA/eval_ASR_TTS/whisper_asr/whisper_asr.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/MMaDA/eval_emova.py b/MMaDA/eval_emova.py new file mode 100644 index 0000000000000000000000000000000000000000..d8070d9d54cfbaad10bdb0814ea414b0b205bfe0 --- /dev/null +++ b/MMaDA/eval_emova.py @@ -0,0 +1,249 @@ +# coding=utf-8 +# Copyright 2025 AIDAS Lab +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import re +import logging +import editdistance +from functools import partial + +os.environ["TOKENIZERS_PARALLETISM"] = "true" + +from tqdm import tqdm +import torch +import torch.distributed as dist +from torch.utils.data import Dataset, DataLoader +from torch.utils.data.distributed import DistributedSampler +from torch.nn.parallel import DistributedDataParallel as DDP + +import wandb +from datasets import load_dataset +from transformers import AutoModel, AutoProcessor + +# --- Helper Functions (from your reference script) --- + +def setup_logger(rank): + """Sets up a logger for each DDP process.""" + logger = logging.getLogger(__name__) + if logger.hasHandlers(): + logger.handlers.clear() + + formatter = logging.Formatter(f'%(asctime)s - [RANK {rank}] - %(levelname)s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S') + ch = logging.StreamHandler() + ch.setFormatter(formatter) + logger.addHandler(ch) + + logger.setLevel(logging.INFO if rank == 0 else logging.WARNING) + return logger + +def calculate_WER(recognized_text_list, groundtruth_text_list): + """Calculates the Word Error Rate (WER) between predicted and ground truth texts.""" + word_num, scores = 0.0, 0.0 + for recognized_text, groundtruth_text in zip(recognized_text_list, groundtruth_text_list): + recognized_text = re.sub(r"[^\w\s']", "", recognized_text.lower()) + groundtruth_text = re.sub(r"[^\w\s']", "", groundtruth_text.lower()) + + recognized_word_list = recognized_text.split() + groundtruth_word_list = groundtruth_text.split() + + current_word_num = len(groundtruth_word_list) + word_num += current_word_num + + scores += editdistance.eval(recognized_word_list, groundtruth_word_list) + + WER = scores / word_num if word_num > 0 else 0.0 + return WER, scores, word_num + +def get_librispeech_dataset(logger, split="test.clean"): + """Loads the Librispeech ASR dataset from Hugging Face.""" + logger.info(f"Loading librispeech_asr dataset ({split})...") + dataset = load_dataset("librispeech_asr", split=split, trust_remote_code=True) + logger.info("Dataset loaded successfully.") + return dataset + +def setup_distributed(rank, world_size): + """Initializes the distributed process group.""" + dist.init_process_group("nccl", rank=rank, world_size=world_size) + +def cleanup_distributed(): + """Cleans up the distributed process group.""" + dist.destroy_process_group() + +# --- Custom Dataset and Collate Function for EMOVA --- + +class LibrispeechAudioDataset(Dataset): + """A simple dataset that returns audio file path and ground truth text.""" + def __init__(self, hf_dataset): + self.hf_dataset = hf_dataset + + def __len__(self): + return len(self.hf_dataset) + + def __getitem__(self, idx): + example = self.hf_dataset[idx] + return { + "audio_path": example['file'], + "gt_text": example['text'], + "sample_id": example['id'] + } + +class EmovaS2TCollateFn: + """ + Collate function to prepare batches for the EMOVA model using its processor. + """ + def __init__(self, processor): + self.processor = processor + self.prompt_text = "Transcribe the given audio." + + def __call__(self, batch): + audio_paths = [item["audio_path"] for item in batch] + gt_texts = [item["gt_text"] for item in batch] + sample_ids = [item["sample_id"] for item in batch] + + # Construct the text input for each audio file in the batch + text_inputs = [ + [ + {"role": "user", "content": [{"type": "audio"}, {"type": "text", "text": self.prompt_text}]} + ] + for _ in audio_paths + ] + + # Use the EMOVA processor to prepare the multimodal batch + inputs = self.processor( + text=text_inputs, + audios=audio_paths, + return_tensors="pt", + padding=True + ) + + inputs['gt_texts'] = gt_texts + inputs['sample_ids'] = sample_ids + return inputs + +def main(): + """Main function to run the distributed evaluation.""" + rank = int(os.environ['RANK']) + world_size = int(os.environ['WORLD_SIZE']) + setup_distributed(rank, world_size) + device = torch.device(f"cuda:{rank}") + logger = setup_logger(rank) + + if rank == 0: + wandb.init(project="emova-librispeech-eval") + + # --- 1. Load EMOVA Models and Processors --- + logger.info("Loading EMOVA models and processors...") + model_name = "Emova-ollm/emova-qwen-2-5-7b-hf" + + model = AutoModel.from_pretrained( + model_name, + torch_dtype=torch.bfloat16, + attn_implementation='flash_attention_2', + low_cpu_mem_usage=True, + trust_remote_code=True + ).to(device) + + processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True) + + speech_tokenizer = AutoModel.from_pretrained( + "Emova-ollm/emova_speech_tokenizer_hf", + torch_dtype=torch.float32, + trust_remote_code=True + ).to(device).eval() + + processor.set_speech_tokenizer(speech_tokenizer) + + # Wrap the main model with DDP + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + logger.info("āœ… Models loaded and wrapped with DDP successfully!") + + # --- 2. Setup DataLoader --- + hf_dataset = get_librispeech_dataset(logger, split="test.clean") + eval_dataset = LibrispeechAudioDataset(hf_dataset) + sampler = DistributedSampler(eval_dataset, num_replicas=world_size, rank=rank, shuffle=False) + + collate_fn = EmovaS2TCollateFn(processor) + + dataloader = DataLoader( + eval_dataset, + batch_size=4, # Adjust batch size based on your GPU memory + sampler=sampler, + num_workers=4, + collate_fn=collate_fn, + pin_memory=True + ) + + # --- 3. Evaluation Loop --- + local_results = [] + model.eval() + + progress_bar = tqdm(dataloader, desc="Evaluating on Librispeech", disable=(rank != 0)) + for batch in progress_bar: + gt_texts = batch.pop("gt_texts") + sample_ids = batch.pop("sample_ids") + + # Move batch tensors to the correct device + inputs = {k: v.to(device) for k, v in batch.items()} + + with torch.no_grad(): + outputs = model.module.generate(**inputs, max_new_tokens=256, do_sample=False) + # Slice to get only the generated tokens + generated_ids = outputs[:, inputs['input_ids'].shape[1]:] + decoded_texts = processor.batch_decode(generated_ids, skip_special_tokens=True) + + for i in range(len(decoded_texts)): + local_results.append({ + "sample_id": sample_ids[i], + "gt_text": gt_texts[i], + "decoded_text": decoded_texts[i].strip() + }) + + if rank == 0 and i == 0 and len(local_results) % 10 == 1: # Log sample every 10 batches on rank 0 + logger.info(f"\n--- Sample ---") + logger.info(f" ID: {sample_ids[i]}") + logger.info(f" GT: {gt_texts[i]}") + logger.info(f" PD: {decoded_texts[i].strip()}") + logger.info(f"----------------") + + # --- 4. Gather Results and Calculate Final Score --- + all_results = [None] * world_size + dist.all_gather_object(all_results, local_results) + + if rank == 0: + logger.info("Gathering and processing results from all GPUs...") + final_results = [item for sublist in all_results for item in sublist] + + gt_list = [res["gt_text"] for res in final_results] + pred_list = [res["decoded_text"] for res in final_results] + + results_table = wandb.Table(columns=["ID", "Ground Truth", "Prediction"]) + for res in final_results: + results_table.add_data(res["sample_id"], res["gt_text"], res["decoded_text"]) + wandb.log({"S2T Predictions": results_table}) + + wer, errors, words = calculate_WER(pred_list, gt_list) + logger.info(f"Final WER (Librispeech test.clean): {wer:.4f} | Word Errors: {errors} | Total Words: {words}") + wandb.log({"WER": wer, "Total Word Errors": errors, "Total Words": words}) + + # --- Cleanup --- + if rank == 0: + wandb.finish() + cleanup_distributed() + +if __name__ == '__main__': + # Set master address and port for DDP + # os.environ['MASTER_ADDR'] = 'localhost' + # os.environ['MASTER_PORT'] = '12355' + main() \ No newline at end of file diff --git a/MMaDA/generate.py b/MMaDA/generate.py new file mode 100644 index 0000000000000000000000000000000000000000..34d80eeefa5452a2971b5e87b2448e76641bc718 --- /dev/null +++ b/MMaDA/generate.py @@ -0,0 +1,146 @@ +import torch +import numpy as np +import torch.nn.functional as F + +from transformers import AutoTokenizer, AutoModel +from models import MMadaModelLM + +def add_gumbel_noise(logits, temperature): + ''' + The Gumbel max is a method for sampling categorical distributions. + According to arXiv:2409.02908, for MDM, low-precision Gumbel Max improves perplexity score but reduces generation quality. + Thus, we use float64. + ''' + if temperature == 0: + return logits + logits = logits.to(torch.float64) + noise = torch.rand_like(logits, dtype=torch.float64) + gumbel_noise = (- torch.log(noise)) ** temperature + return logits.exp() / gumbel_noise + + +def get_num_transfer_tokens(mask_index, steps): + ''' + In the reverse process, the interval [0, 1] is uniformly discretized into steps intervals. + Furthermore, because LLaDA employs a linear noise schedule (as defined in Eq. (8)), + the expected number of tokens transitioned at each step should be consistent. + + This function is designed to precompute the number of tokens that need to be transitioned at each step. + ''' + mask_num = mask_index.sum(dim=1, keepdim=True) + + base = mask_num // steps + remainder = mask_num % steps + + num_transfer_tokens = torch.zeros(mask_num.size(0), steps, device=mask_index.device, dtype=torch.int64) + base + + for i in range(mask_num.size(0)): + num_transfer_tokens[i, :remainder[i]] += 1 + + return num_transfer_tokens + + +@ torch.no_grad() +def generate(model, prompt, steps=128, gen_length=128, block_length=128, temperature=0., + cfg_scale=0., remasking='low_confidence', mask_id=126336, attention_mask=None): + ''' + Args: + model: Mask predictor. + prompt: A tensor of shape (B, L), where B is batch size. + steps: Sampling steps, less than or equal to gen_length. + gen_length: Generated answer length. + block_length: Block length, less than or equal to gen_length. If less than gen_length, it means using semi_autoregressive remasking. + temperature: Categorical distribution sampling temperature. + cfg_scale: Unsupervised classifier-free guidance scale. + remasking: Remasking strategy. 'low_confidence' or 'random'. + mask_id: The toke id of [MASK] is 126336. + ''' + if attention_mask is not None and 0.0 in attention_mask: + attention_bias = (attention_mask[:, :, None] & attention_mask[:, None, :]).bool().unsqueeze(1) + print(f"attention_bias: {attention_bias}") + else: + attention_bias = None + batch_size = prompt.shape[0] + x = torch.full((batch_size, prompt.shape[1] + gen_length), mask_id, dtype=torch.long).to(model.device) + x[:, :prompt.shape[1]] = prompt.clone() + + prompt_index = (x != mask_id) + + assert gen_length % block_length == 0 + num_blocks = gen_length // block_length + + assert steps % num_blocks == 0 + steps = steps // num_blocks + + for num_block in range(num_blocks): + block_mask_index = (x[:, prompt.shape[1] + num_block * block_length: prompt.shape[1] + (num_block + 1) * block_length:] == mask_id) + num_transfer_tokens = get_num_transfer_tokens(block_mask_index, steps) + for i in range(steps): + mask_index = (x == mask_id) + if cfg_scale > 0.: + un_x = x.clone() + un_x[prompt_index] = mask_id + x_ = torch.cat([x, un_x], dim=0) + logits = model(x_).logits + logits, un_logits = torch.chunk(logits, 2, dim=0) + logits = un_logits + (cfg_scale + 1) * (logits - un_logits) + else: + logits = model(x, attention_bias=attention_bias).logits + + logits_with_noise = add_gumbel_noise(logits, temperature=temperature) + x0 = torch.argmax(logits_with_noise, dim=-1) # b, l + + if remasking == 'low_confidence': + p = F.softmax(logits.to(torch.float64), dim=-1) + x0_p = torch.squeeze( + torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1) # b, l + elif remasking == 'random': + x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device) + else: + raise NotImplementedError(remasking) + + x0_p[:, prompt.shape[1] + (num_block + 1) * block_length:] = -np.inf + + x0 = torch.where(mask_index, x0, x) + confidence = torch.where(mask_index, x0_p, -np.inf) + # print(confidence.shape) + transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device) + for j in range(confidence.shape[0]): + _, select_index = torch.topk(confidence[j], k=num_transfer_tokens[j, i]) + transfer_index[j, select_index] = True + x[transfer_index] = x0[transfer_index] + + return x + + +def main(): + device = 'cuda' + # Load from HF + + # model = MMadaModelLM.from_pretrained("Gen-Verse/MMaDA-8B-Base", trust_remote_code=True, torch_dtype=torch.bfloat16).to(device).eval() + # tokenizer = AutoTokenizer.from_pretrained("Gen-Verse/MMaDA-8B-Base", trust_remote_code=True) + + train_step = 135000 + trained_checkpoint_path = f"/home/work/AIDAS/ckpts/omada/omada-training-stage1/checkpoint-{train_step}/unwrapped_model/" + + model = MMadaModelLM.from_pretrained( + trained_checkpoint_path, + trust_remote_code=True, + torch_dtype=torch.bfloat16, + config="/home/work/AIDAS/ckpts/omada/omada-training-stage1/config.json" + ).to(device) + + tokenizer = AutoTokenizer.from_pretrained("Gen-Verse/MMaDA-8B-MixCoT", trust_remote_code=True) + + tokenizer.chat_template = "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{{ '<|start_header_id|>assistant<|end_header_id|>\n' }}" + prompt = "Lily can run 12 kilometers per hour for 4 hours. After that, she runs 6 kilometers per hour. How many kilometers can she run in 8 hours?" + m = [{"role": "user", "content": prompt}, ] + prompt = tokenizer.apply_chat_template(m, add_generation_prompt=True, tokenize=False) + input_ids = tokenizer(text=prompt, return_tensors="pt", padding=True, padding_side="left")['input_ids'] + input_ids = torch.tensor(input_ids).to(device) + out = generate(model, input_ids, steps=128, gen_length=128, block_length=128, temperature=1, cfg_scale=0., remasking='low_confidence') + print(tokenizer.batch_decode(out[:, input_ids.shape[1]:], skip_special_tokens=True)) + + +if __name__ == '__main__': + main() diff --git a/MMaDA/generate_mmada.py b/MMaDA/generate_mmada.py new file mode 100644 index 0000000000000000000000000000000000000000..c72da8d9295f79c227eb040fd50d176ac3a1c69e --- /dev/null +++ b/MMaDA/generate_mmada.py @@ -0,0 +1,132 @@ +import torch +import numpy as np +import torch.nn.functional as F + +from transformers import AutoTokenizer, AutoModel +from models import MMadaModelLM + +def add_gumbel_noise(logits, temperature): + ''' + The Gumbel max is a method for sampling categorical distributions. + According to arXiv:2409.02908, for MDM, low-precision Gumbel Max improves perplexity score but reduces generation quality. + Thus, we use float64. + ''' + if temperature == 0: + return logits + logits = logits.to(torch.float64) + noise = torch.rand_like(logits, dtype=torch.float64) + gumbel_noise = (- torch.log(noise)) ** temperature + return logits.exp() / gumbel_noise + + +def get_num_transfer_tokens(mask_index, steps): + ''' + In the reverse process, the interval [0, 1] is uniformly discretized into steps intervals. + Furthermore, because LLaDA employs a linear noise schedule (as defined in Eq. (8)), + the expected number of tokens transitioned at each step should be consistent. + + This function is designed to precompute the number of tokens that need to be transitioned at each step. + ''' + mask_num = mask_index.sum(dim=1, keepdim=True) + + base = mask_num // steps + remainder = mask_num % steps + + num_transfer_tokens = torch.zeros(mask_num.size(0), steps, device=mask_index.device, dtype=torch.int64) + base + + for i in range(mask_num.size(0)): + num_transfer_tokens[i, :remainder[i]] += 1 + + return num_transfer_tokens + + +@ torch.no_grad() +def generate(model, prompt, steps=128, gen_length=128, block_length=128, temperature=0., + cfg_scale=0., remasking='low_confidence', mask_id=126336, attention_mask=None): + ''' + Args: + model: Mask predictor. + prompt: A tensor of shape (B, L), where B is batch size. + steps: Sampling steps, less than or equal to gen_length. + gen_length: Generated answer length. + block_length: Block length, less than or equal to gen_length. If less than gen_length, it means using semi_autoregressive remasking. + temperature: Categorical distribution sampling temperature. + cfg_scale: Unsupervised classifier-free guidance scale. + remasking: Remasking strategy. 'low_confidence' or 'random'. + mask_id: The toke id of [MASK] is 126336. + ''' + if attention_mask is not None and 0.0 in attention_mask: + attention_bias = (attention_mask[:, :, None] & attention_mask[:, None, :]).bool().unsqueeze(1) + print(f"attention_bias: {attention_bias}") + else: + attention_bias = None + batch_size = prompt.shape[0] + x = torch.full((batch_size, prompt.shape[1] + gen_length), mask_id, dtype=torch.long).to(model.device) + x[:, :prompt.shape[1]] = prompt.clone() + + prompt_index = (x != mask_id) + + assert gen_length % block_length == 0 + num_blocks = gen_length // block_length + + assert steps % num_blocks == 0 + steps = steps // num_blocks + + for num_block in range(num_blocks): + block_mask_index = (x[:, prompt.shape[1] + num_block * block_length: prompt.shape[1] + (num_block + 1) * block_length:] == mask_id) + num_transfer_tokens = get_num_transfer_tokens(block_mask_index, steps) + for i in range(steps): + mask_index = (x == mask_id) + if cfg_scale > 0.: + un_x = x.clone() + un_x[prompt_index] = mask_id + x_ = torch.cat([x, un_x], dim=0) + logits = model(x_).logits + logits, un_logits = torch.chunk(logits, 2, dim=0) + logits = un_logits + (cfg_scale + 1) * (logits - un_logits) + else: + logits = model(x, attention_bias=attention_bias).logits + + logits_with_noise = add_gumbel_noise(logits, temperature=temperature) + x0 = torch.argmax(logits_with_noise, dim=-1) # b, l + + if remasking == 'low_confidence': + p = F.softmax(logits.to(torch.float64), dim=-1) + x0_p = torch.squeeze( + torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1) # b, l + elif remasking == 'random': + x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device) + else: + raise NotImplementedError(remasking) + + x0_p[:, prompt.shape[1] + (num_block + 1) * block_length:] = -np.inf + + x0 = torch.where(mask_index, x0, x) + confidence = torch.where(mask_index, x0_p, -np.inf) + # print(confidence.shape) + transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device) + for j in range(confidence.shape[0]): + _, select_index = torch.topk(confidence[j], k=num_transfer_tokens[j, i]) + transfer_index[j, select_index] = True + x[transfer_index] = x0[transfer_index] + + return x + + +def main(): + device = 'cuda' + # Load from HF + model = MMadaModelLM.from_pretrained("Gen-Verse/MMaDA-8B-Base", trust_remote_code=True, torch_dtype=torch.bfloat16).to(device).eval() + tokenizer = AutoTokenizer.from_pretrained("Gen-Verse/MMaDA-8B-Base", trust_remote_code=True) + tokenizer.chat_template = "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{{ '<|start_header_id|>assistant<|end_header_id|>\n' }}" + prompt = "Lily can run 12 kilometers per hour for 4 hours. After that, she runs 6 kilometers per hour. How many kilometers can she run in 8 hours?" + m = [{"role": "user", "content": prompt}, ] + prompt = tokenizer.apply_chat_template(m, add_generation_prompt=True, tokenize=False) + input_ids = tokenizer(text=prompt, return_tensors="pt", padding=True, padding_side="left")['input_ids'] + input_ids = torch.tensor(input_ids).to(device) + out = generate(model, input_ids, steps=128, gen_length=128, block_length=128, temperature=1, cfg_scale=0., remasking='low_confidence') + print(tokenizer.batch_decode(out[:, input_ids.shape[1]:], skip_special_tokens=True)) + + +if __name__ == '__main__': + main() diff --git a/MMaDA/inference/common.py b/MMaDA/inference/common.py new file mode 100644 index 0000000000000000000000000000000000000000..9f3b72f45d0dee8c339a438d75458edf21df2711 --- /dev/null +++ b/MMaDA/inference/common.py @@ -0,0 +1,154 @@ +import os +import itertools +import json +from pathlib import Path +from typing import Dict, Iterable, List, Tuple + +import torch +import wandb +from omegaconf import OmegaConf + +from transformers import AutoTokenizer + +from models import MAGVITv2, OMadaModelLM +from models.modeling_emova_speech_tokenizer import EMOVASpeechTokenizer +from training.prompting_utils import UniversalPrompting + + +def load_train_config(path: str): + cfg = OmegaConf.load(path) + return cfg + + +def get_vq_model_image(cfg, device): + vq_cfg = cfg.model.vq_model_image + if getattr(vq_cfg, "pretrained_model_path", None): + model = MAGVITv2().to(device) + state_dict = torch.load(vq_cfg.pretrained_model_path)["model"] + model.load_state_dict(state_dict) + return model.eval() + else: + return MAGVITv2.from_pretrained(vq_cfg.vq_model_name).to(device).eval() + + +def get_vq_model_audio(cfg, device): + vq_cfg = cfg.model.vq_model_audio + # Always EMOVA for now + model = EMOVASpeechTokenizer.from_pretrained(vq_cfg.vq_model_name) + model = model.to(device) + model.eval() + return model + + +def build_uni_prompting(cfg) -> Tuple[UniversalPrompting, AutoTokenizer]: + tokenizer = AutoTokenizer.from_pretrained(cfg.model.omada.tokenizer_path, padding_side="left") + uni_prompting = UniversalPrompting( + tokenizer, + max_text_len=cfg.dataset.preprocessing.max_seq_length, + max_audio_len=cfg.dataset.preprocessing.max_aud_length, + special_tokens=( + "<|soi|>", "<|eoi|>", "<|sov|>", "<|eov|>", "<|t2i|>", + "<|mmu|>", "<|t2v|>", "<|v2v|>", "<|lvg|>", + "<|i2i|>", "<|v2s|>", "<|s2s|>", + "<|v2t|>", "<|s2t|>", "<|t2s|>", "<|soa|>", "<|eoa|>", + ), + ignore_id=-100, + cond_dropout_prob=cfg.training.cond_dropout_prob, + use_reserved_token=True, + ) + return uni_prompting, tokenizer + + +def load_omada_from_checkpoint(ckpt_unwrapped_dir: str, device, torch_dtype=torch.bfloat16) -> OMadaModelLM: + """Load OMada model weights from an `unwrapped_model` directory. + + The helper used to rely on a hard-coded config path which broke when + evaluating checkpoints from other training steps. We now detect the + config.json co-located with the weights so any checkpoint exported by the + trainer can be used directly. + """ + + ckpt_path = Path(ckpt_unwrapped_dir) + if not ckpt_path.is_dir(): + raise FileNotFoundError(f"Expected an 'unwrapped_model' directory, got {ckpt_unwrapped_dir}") + + config_path = ckpt_path / "config.json" + config_arg = str(config_path) if config_path.exists() else None + + model = OMadaModelLM.from_pretrained( + ckpt_unwrapped_dir, + torch_dtype=torch_dtype, + config=config_arg, + trust_remote_code=True, + ).to(device) + model.eval() + return model + + +def list_checkpoints(ckpt_root: str) -> List[str]: + """Return a sorted list of checkpoint 'unwrapped_model' dirs under a training output dir or a direct ckpt dir. + Accepts either: + - A path that already ends with 'unwrapped_model' + - A path to 'checkpoint-XXXX' (we append 'unwrapped_model') + - A path to the experiment output dir that contains many 'checkpoint-*' + """ + p = Path(ckpt_root) + if p.name == "unwrapped_model" and p.is_dir(): + return [str(p)] + if p.name.startswith("checkpoint-") and p.is_dir(): + inner = p / "unwrapped_model" + return [str(inner)] if inner.is_dir() else [] + # otherwise, collect children checkpoints + outs = [] + for child in p.iterdir(): + if child.is_dir() and child.name.startswith("checkpoint-"): + inner = child / "unwrapped_model" + if inner.is_dir(): + outs.append(str(inner)) + # sort by numeric step if possible + def step_key(s: str): + try: + return int(Path(s).parent.name.split("-")[-1]) + except Exception: + return -1 + outs.sort(key=step_key) + return outs + + +def grid_dict(product_space: Dict[str, Iterable]) -> List[Dict]: + """Expand a dict of lists to a list of dict combinations. + Example: {a:[1,2], b:["x"]} -> [{a:1,b:"x"},{a:2,b:"x"}] + """ + keys = list(product_space.keys()) + values = [list(v if isinstance(v, (list, tuple)) else [v]) for v in product_space.values()] + combos = [] + for vals in itertools.product(*values): + combos.append({k: v for k, v in zip(keys, vals)}) + return combos + + +def init_wandb(infer_cfg: Dict, task: str, ckpt_path: str, hparams: Dict): + wcfg = infer_cfg.get("wandb", {}) + project = wcfg.get("project", f"omada-inference-{task}") + entity = wcfg.get("entity") + group = wcfg.get("group", f"{task}") + name_prefix = wcfg.get("name_prefix", f"{task}") + step_str = Path(ckpt_path).parent.name + run_name = f"{name_prefix}-{step_str}-" + ",".join([f"{k}={v}" for k, v in hparams.items()]) + tags = wcfg.get("tags", []) + + wandb.init(project=project, entity=entity, group=group, name=run_name, tags=tags, config={ + "task": task, + "checkpoint": ckpt_path, + "hparams": hparams, + }) + + +def safe_log_table(name: str, columns: List[str], rows: List[List]): + try: + table = wandb.Table(columns=columns) + for r in rows: + table.add_data(*r) + wandb.log({name: table}) + except Exception: + pass diff --git a/MMaDA/inference/configs/s2t_infer_example.yaml b/MMaDA/inference/configs/s2t_infer_example.yaml new file mode 100644 index 0000000000000000000000000000000000000000..17b8c0a7d3d1415b191a6f9598c3f92ed5cc0531 --- /dev/null +++ b/MMaDA/inference/configs/s2t_infer_example.yaml @@ -0,0 +1,24 @@ +wandb: + project: omada-inference-s2t + group: s2t-grid + name_prefix: s2t + tags: [s2t, inference] + +# Optional: list one or multiple checkpoint roots. +# If omitted, the script will expand from --ckpt_root argument. +# checkpoints: +# - ckpts/omada/omada-training-stage1_3rd + +generation: + steps: [128, 256] + block_length: [64] + max_new_tokens: [256] + remasking: [low_confidence] + batch_size: [1] + +dataset: + subset: clean + split: test + root_path: /home/work/AIDAS/data/audio/LibriSpeech/test-clean + limit: 64 + diff --git a/MMaDA/inference/configs/t2s_infer_example.yaml b/MMaDA/inference/configs/t2s_infer_example.yaml new file mode 100644 index 0000000000000000000000000000000000000000..19cda7f62966e99a92e2354895a3bb965bfbb549 --- /dev/null +++ b/MMaDA/inference/configs/t2s_infer_example.yaml @@ -0,0 +1,20 @@ +wandb: + project: omada-inference-t2s + group: t2s-grid + name_prefix: t2s + tags: [t2s, inference] + +generation: + mode: [fixed] + guidance_scale: [1.5] + temperature: [1.0] + timesteps: [24] + seq_len: [254] + batch_size: [4] + output_dir: [outputs/t2s] + +dataset: + subset: clean + split: test + limit: 16 + diff --git a/MMaDA/inference/configs/v2t_infer_example.yaml b/MMaDA/inference/configs/v2t_infer_example.yaml new file mode 100644 index 0000000000000000000000000000000000000000..09edb124b178233c99e4c54d6597afe40f7436a6 --- /dev/null +++ b/MMaDA/inference/configs/v2t_infer_example.yaml @@ -0,0 +1,17 @@ +wandb: + project: omada-inference-v2t + group: v2t-grid + name_prefix: v2t + tags: [v2t, inference] + +generation: + steps: [256] + block_length: [128] + max_new_tokens: [256] + +dataset: + video_dir: /home/work/AIDAS/video/demo + questions: + - "Please provide a detailed description of the video." + - "Summarize the events in the video." + diff --git a/MMaDA/inference/gradio_multimodal_demo2.py b/MMaDA/inference/gradio_multimodal_demo2.py new file mode 100644 index 0000000000000000000000000000000000000000..c962a5a832a4f9371f9b2c6cfadea88e2d160711 --- /dev/null +++ b/MMaDA/inference/gradio_multimodal_demo2.py @@ -0,0 +1,1172 @@ +#!/usr/bin/env python3 +""" +Gradio demo for OMada Stage 1.3 checkpoints covering: + * Text-to-Speech (T2S) + * Speech-to-Text (S2T) + * Video-to-Text (V2T) + +The implementation wraps the existing CLI inference helpers so that a single +checkpoint directory (…/checkpoint-XXXX/unwrapped_model) can be previewed +interactively. Usage: + + python MMaDA/inference/gradio_multimodal_demo2.py --train-config MMaDA/configs/omada_pretraining_stage1-3.yaml --checkpoint ../ckpt/checkpoint-315000/unwrapped_model/ --share + +If you need remote access, pass `--share` which simply forwards the flag to +`gradio.Blocks.launch`. +""" + +import argparse +import base64 +import html +import io +import os +import random +import sys +import tempfile +import wave +from pathlib import Path +import shutil +import time +from typing import Any, Optional, Tuple +import numpy as np + + +CUSTOM_CSS = """ +:root { + --omada-primary: #1e3a8a; + --omada-accent: #1d4ed8; + --omada-surface: #f3f4f6; + --omada-surface-alt: #ffffff; + --omada-border: #d0d7e5; + --omada-text-primary: #111827; + --omada-text-muted: #374151; + color-scheme: light; +} +html, body, body.dark, html.dark { + background: var(--omada-surface) !important; + color: var(--omada-text-primary) !important; +} +.gradio-container { + background: var(--omada-surface); + color: var(--omada-text-primary); +} +.omada-page-heading { + margin-bottom: 0; + color: var(--omada-text-primary); +} +.omada-tab-intro p { + font-size: 0.95rem; + color: var(--omada-text-primary); + margin-top: 0; + opacity: 0.9; +} +.omada-card { + background: var(--omada-surface-alt); + border-radius: 16px !important; + padding: 18px !important; + box-shadow: none; + border: 1px solid var(--omada-border); + color: var(--omada-text-primary); +} +.omada-card .gradio-slider .wrap-inner { + gap: 6px; +} +.omada-card .gradio-slider input[type=range]::-webkit-slider-thumb { + background: var(--omada-primary); +} +.gradio-slider input[type=range]::-webkit-slider-runnable-track { + background: rgba(14, 33, 80, 0.2); +} +.omada-section-title p { + text-transform: uppercase; + font-size: 0.78rem; + letter-spacing: 0.14em; + color: rgba(30, 58, 138, 0.85); + margin: 0 0 12px 0; +} +.omada-output .gradio-audio, .omada-output .gradio-textbox { + margin-top: 12px; +} +.gradio-textbox, .gradio-dropdown, .gradio-slider, .gradio-audio { + color: var(--omada-text-primary) !important; +} +.gradio-dropdown .wrap .label, .gradio-textbox label, .gradio-slider label { + color: var(--omada-text-primary); +} +.gradio-dropdown .single-select, .gradio-textbox textarea { + background: #ffffff !important; + border: 1px solid var(--omada-border) !important; + color: var(--omada-text-primary) !important; +} +.gradio-dropdown .single-select select, .gradio-textbox textarea { + color: var(--omada-text-primary) !important; +} +.gradio-textbox textarea::placeholder { + color: rgba(148, 163, 184, 0.65); +} +.gradio-dropdown, .gradio-textbox, .gradio-audio, .gradio-video, .gradio-slider { + background: #ffffff !important; + border: 1px solid var(--omada-border) !important; + border-radius: 12px !important; +} +.full-width-button button { + width: 100%; + background: var(--omada-primary) !important; + color: white !important; + border: none !important; + font-weight: 600; + transition: transform 0.2s ease, box-shadow 0.2s ease; + box-shadow: 0 12px 30px -12px rgba(79, 70, 229, 0.65); +} +.full-width-button button:hover { + transform: translateY(-1px); + box-shadow: 0 18px 34px -14px rgba(79, 70, 229, 0.75); +} +.omada-advanced .gr-accordion-header { + font-size: 0.85rem; + letter-spacing: 0.05em; + color: var(--omada-text-muted); +} +.omada-advanced .gr-accordion { + border: 1px solid var(--omada-border); + border-radius: 12px; + background: #ffffff; +} +.gradio-tabs { + background: transparent; +} +.gradio-tabs ul.tab-list { + background: transparent; + border-bottom: 1px solid var(--omada-border); +} +.gradio-tabs button { + color: var(--omada-text-primary); +} +.gradio-tabs button.selected { + color: var(--omada-text-primary); + background: rgba(14, 33, 80, 0.1); + border-bottom: 2px solid var(--omada-primary); +} +.gradio-container .label { + background: rgba(30, 58, 138, 0.1) !important; + color: var(--omada-primary) !important; + border: 1px solid rgba(30, 58, 138, 0.25) !important; + border-radius: 999px !important; + padding: 4px 12px !important; +} +.gradio-button.primary { + background: var(--omada-primary) !important; + color: #ffffff !important; + border: 1px solid var(--omada-primary) !important; +} +.gradio-accordion { + box-shadow: none; +} +.omada-layout { + gap: 20px !important; +} +.omada-chat-column { + gap: 12px !important; +} +.omada-chat-column .gradio-chatbot { + border-radius: 16px; + box-shadow: none; + border: 1px solid var(--omada-border); + background: #ffffff; +} +.omada-controls { + gap: 16px !important; +} +.omada-mode-panel { + display: flex; + flex-direction: column; + gap: 16px !important; +} +.omada-examples-card { + padding-top: 10px !important; +} +.omada-output-panel .gradio-audio, +.omada-output-panel .gradio-textbox { + margin-top: 8px; +} +.omada-response-container { + display: flex; + flex-direction: column; + gap: 10px; +} +.omada-response-status { + margin: 0; + font-weight: 600; + font-size: 0.95rem; + color: var(--omada-text-primary); +} +.omada-response-block, +.omada-audio-block { + background: rgba(30, 58, 138, 0.05); + border-radius: 12px; + padding: 12px 14px; + color: var(--omada-text-primary); + white-space: pre-wrap; + word-break: break-word; +} +.omada-audio-block audio { + width: 100%; +} +.omada-header { + display: flex; + flex-direction: column; + align-items: center; + justify-content: center; + gap: 18px !important; + margin-bottom: 18px; + text-align: center; +} +.omada-header .gradio-image { + background: transparent !important; + border: none !important; +} +.omada-header img { + object-fit: contain; + display: block; +} +.omada-logo { + max-width: 180px; + padding: 0 !important; +} +.omada-examples { + margin-top: 8px; + padding-top: 4px; +} +.omada-logo .gradio-image, +.omada-logo .gradio-image > div, + .omada-logo .gradio-image .container { + background: transparent !important; + border: none !important; + box-shadow: none !important; + padding: 0 !important; + display: flex; + justify-content: center; + align-items: center; + } + .omada-logo img { + width: 100%; + height: auto; + } +.omada-logo button { + display: none !important; +} +.gradio-container .gradio-component, +.gradio-container .gradio-panel, +.gradio-container .gradio-box { + background: transparent !important; + color: var(--omada-text-primary); +} +.dark .gradio-container, +.dark .gradio-interface, +.dark .gradio-container * { + background-color: inherit !important; + color: var(--omada-text-primary) !important; +} +.dark .gradio-container .gradio-chatbot, +.dark .gradio-container .gradio-dropdown, +.dark .gradio-container .gradio-textbox, +.dark .gradio-container .gradio-audio, +.dark .gradio-container .gradio-video, +.dark .gradio-container .gradio-slider, +.dark .gradio-container .gradio-accordion, +.dark .gradio-container .gradio-panel, +.dark .gradio-container .gradio-box { + background: #ffffff !important; + border-color: var(--omada-border) !important; +} +.omada-title h2 { + font-size: 2.4rem; + font-weight: 700; + color: var(--omada-text-primary); + margin: 0; +} +.omada-title h3 { + font-size: 1.25rem; + font-weight: 600; + letter-spacing: 0.1em; + text-transform: uppercase; + color: var(--omada-text-muted); + margin: 6px 0 0; +} +.omada-tagline p { + color: var(--omada-text-primary); + font-size: 1rem; + margin: 0; + opacity: 0.9; +} +.gradio-container .prose :where(h1, h2, h3, h4, h5, h6) { + color: var(--omada-text-primary) !important; +} +.gradio-container .prose :where(p, li) { + color: var(--omada-text-muted) !important; +} +.gradio-container label, .gradio-container span, .gradio-container button { + color: var(--omada-text-primary); +} +.gradio-container .dark { + background: #ffffff !important; + color: var(--omada-text-primary) !important; +} +.omada-logo-img { + max-width: 250px; + width: 100%; + height: auto; + display: block; + margin: 0 auto; +} +.omada-logo-wrapper { + display: flex; + justify-content: center; + align-items: center; +} +""" + +DEMO_ROOT = Path(__file__).resolve().parent / "demo" +LOGO_PATH = DEMO_ROOT / "logo.png" +T2S_TEXT_PATH = DEMO_ROOT / "t2s" / "text.txt" + + +def _load_logo_data() -> Optional[str]: + if not LOGO_PATH.exists(): + return None + try: + import base64 + except ImportError: + return str(LOGO_PATH) + try: + encoded = base64.b64encode(LOGO_PATH.read_bytes()).decode("utf-8") + except OSError: + return str(LOGO_PATH) + return f"data:image/png;base64,{encoded}" + + +def _load_t2s_examples(): + if not T2S_TEXT_PATH.exists(): + return [] + lines = [ + line.strip() + for line in T2S_TEXT_PATH.read_text(encoding="utf-8").splitlines() + if line.strip() + ] + return [[line] for line in lines] + + +def _load_media_examples(subdir: str, suffixes): + target_dir = DEMO_ROOT / subdir + if not target_dir.exists(): + return [] + examples = [] + for path in sorted(target_dir.iterdir()): + if path.is_file() and path.suffix.lower() in suffixes: + examples.append([str(path)]) + return examples + + +T2S_EXAMPLES = _load_t2s_examples() +S2T_EXAMPLES = _load_media_examples("s2t", {".wav", ".mp3", ".flac", ".ogg"}) +V2T_EXAMPLES = _load_media_examples("v2t", {".mp4", ".mov", ".avi", ".webm"}) +LOGO_DATA_URI = _load_logo_data() + + +def _render_response(status: str, body_html: str = "") -> str: + safe_status = html.escape(status or "") + parts = [] + if safe_status: + parts.append(f"

{safe_status}

") + if body_html: + parts.append(body_html) + content = "".join(parts) + return f"
{content}
" + + +def _render_text_message(status: str, content: Optional[str]) -> str: + content = (content or "").strip() + if not content: + return _render_response(status) + safe_content = html.escape(content).replace("\n", "
") + body = f"
{safe_content}
" + return _render_response(status, body) + + +def _render_audio_message(status: str, audio: Optional[Tuple[int, np.ndarray]]) -> str: + """Render an inline HTML audio player for chat responses.""" + + if not audio: + return _render_response(status) + + sample_rate, data = audio + if data is None: + return _render_response(status) + + waveform = np.asarray(data, dtype=np.float32) + if waveform.size == 0: + return _render_response(status) + + if waveform.ndim == 1: + waveform = waveform[:, None] + + channels = waveform.shape[1] + clipped = np.clip(waveform, -1.0, 1.0) + pcm16 = (clipped * 32767.0).astype(np.int16) + + buffer = io.BytesIO() + with wave.open(buffer, "wb") as wav_writer: + wav_writer.setnchannels(channels) + wav_writer.setsampwidth(2) # 16-bit PCM + wav_writer.setframerate(int(sample_rate)) + wav_writer.writeframes(pcm16.tobytes()) + + encoded = base64.b64encode(buffer.getvalue()).decode("ascii") + audio_tag = ( + "
" + "" + "
" + ) + return _render_response(status, audio_tag) + + +def _format_user_message(message: str) -> str: + clean = html.escape(message or "") + return clean.replace("\n", "
") + +# Ensure project modules (models, training, inference.common, …) are importable when the +# script is launched directly via `python MMaDA/inference/gradio_multimodal_demo.py`. +PROJECT_ROOT = Path(__file__).resolve().parents[1] +if str(PROJECT_ROOT) not in sys.path: + sys.path.insert(0, str(PROJECT_ROOT)) + +import cv2 +import gradio as gr +import numpy as np +import torch +from PIL import Image + +from inference.common import ( + build_uni_prompting, + get_vq_model_audio, + get_vq_model_image, + load_omada_from_checkpoint, + load_train_config, +) +from models import get_mask_schedule +from training.data import S2T_INSTRUCTION, T2S_INSTRUCTION, V2T_INSTRUCTION +from training.utils import image_transform + + +def _resolve_noise_schedule(train_cfg) -> callable: + """Return the diffusion noise schedule used for T2S sampling.""" + + schedule_cfg = getattr(train_cfg, "mask_schedule", None) + if schedule_cfg and hasattr(schedule_cfg, "schedule"): + schedule_name = schedule_cfg.schedule + schedule_kwargs = schedule_cfg.get("params", {}) + return get_mask_schedule(schedule_name, **schedule_kwargs) + + schedule_name = train_cfg.training.get("mask_schedule", "cosine") + return get_mask_schedule(schedule_name) + + +class OmadaDemo: + """Lightweight container that loads all inference assets once.""" + + def __init__(self, train_config: str, checkpoint: str, device: Optional[str] = None): + ckpt_path = Path(checkpoint) + if ckpt_path.name != "unwrapped_model": + raise ValueError( + "`--checkpoint` must point to an `unwrapped_model` directory. " + f"Received: {checkpoint}" + ) + + self.device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu")) + self.train_cfg = load_train_config(train_config) + self.uni_prompting, _ = build_uni_prompting(self.train_cfg) + + # Core models + self.model = load_omada_from_checkpoint(str(ckpt_path), self.device) + self.vq_audio = get_vq_model_audio(self.train_cfg, self.device) + self.vq_image = get_vq_model_image(self.train_cfg, self.device) + + self.model.eval() + self.vq_audio.eval() + self.vq_image.eval() + + # Cached constants + self.mask_token_id = int(self.model.config.mask_token_id) + self.noise_schedule = _resolve_noise_schedule(self.train_cfg) + self.sample_rate = int(getattr(self.vq_audio.u2s_config.data, "sampling_rate", 22050)) + + self.genders = ['female', 'male'] + self.emotions = ['angry', 'happy', 'neutral', 'sad'] + self.speeds = ['normal', 'fast', 'slow'] + self.pitches = ['normal', 'high', 'low'] + + # Pre-computed offsets reused across calls + self.text_vocab_size = len(self.uni_prompting.text_tokenizer) + self.codebook_size = int(getattr(self.train_cfg.model.omada, "codebook_size", 8192)) + self.speech_codebook = self.codebook_size + self._temp_video_files = [] + + # ------------------------------------------------------------------ + # Text-to-Speech + # ------------------------------------------------------------------ + def run_t2s( + self, + text: str, + max_new_tokens: int, + steps: int, + block_length: int, + temperature: float, + cfg_scale: float, + gender_choice: str, + emotion_choice: str, + speed_choice: str, + pitch_choice: str, + ) -> Tuple[Optional[Tuple[int, np.ndarray]], str]: + + if text is None or not text.strip(): + return None, "Please provide text to synthesize." + + speech_len = int(max_new_tokens) + if speech_len <= 0: + return None, "Speech token length must be positive." + + gender = self._resolve_choice(gender_choice, self.genders) + emotion = self._resolve_choice(emotion_choice, self.emotions) + speed = self._resolve_choice(speed_choice, self.speeds) + pitch = self._resolve_choice(pitch_choice, self.pitches) + + text = text.strip().upper() + prompt = ( + "<|start_header_id|>user<|end_header_id|>\n" + f"{random.choice(T2S_INSTRUCTION)}\n{text}" + "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n" + ) + + audio_tokens = torch.full( + (1, speech_len), + fill_value=self.mask_token_id, + dtype=torch.long, + device=self.device, + ) + + input_ids, attention_mask = self.uni_prompting(([prompt], audio_tokens), "t2s_gen") + input_ids = input_ids.to(self.device) + attention_mask = attention_mask.to(self.device) + + with torch.no_grad(): + outputs = self.model.t2s_generate_mmu_like( + input_ids=input_ids, + max_new_tokens=int(max_new_tokens), + steps=int(steps), + block_length=int(block_length), + temperature=float(temperature), + cfg_scale=float(cfg_scale), + mask_token_id=self.mask_token_id, + attention_mask=attention_mask, + uni_prompting=self.uni_prompting, + codebook_size=self.codebook_size, + ) + + if not outputs: + return None, "Generation produced no speech tokens." + + rel = outputs[0] + if isinstance(rel, torch.Tensor): + rel_ids = rel.detach().cpu().tolist() + else: + rel_ids = list(rel) + + if not rel_ids: + return None, "Generation produced no speech tokens." + + speech_units = "".join(f"<|speech_{sid}|>" for sid in rel_ids) + condition = f"gender-{gender}_emotion-{emotion}_speed-{speed}_pitch-{pitch}" + wav = self.vq_audio.decode( + speech_units, + condition=condition, + output_wav_file=os.path.join("/tmp", "omada_t2s.wav"), + ) + + audio = (self.sample_rate, wav.astype(np.float32)) + status = f"Speech generated! ({gender}/{emotion}/{speed}/{pitch})." + return audio, status + + # ------------------------------------------------------------------ + # Speech-to-Text + # ------------------------------------------------------------------ + def run_s2t( + self, + audio_path: Optional[str], + steps: int, + block_length: int, + max_new_tokens: int, + remasking: str, + ) -> Tuple[str, str]: + + if not audio_path: + return "", "Please upload an audio file first." + + tokens = self.vq_audio.encode(audio_path).to(self.device) + offset = self.text_vocab_size + self.speech_codebook + tokens = tokens + offset + + spt = self.uni_prompting.sptids_dict + audio_block = torch.cat( + [ + spt['<|s2t|>'].to(self.device).unsqueeze(0), + spt['<|soa|>'].to(self.device).unsqueeze(0), + tokens.to(self.device), + spt['<|eoa|>'].to(self.device).unsqueeze(0), + ], + dim=1, + ) + + prompt_text = random.choice(S2T_INSTRUCTION) + chat_prompt = ( + "<|start_header_id|>user<|end_header_id|>\n" + f"{prompt_text}" + "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n" + ) + prompt_tensor = self.uni_prompting.text_tokenizer( + chat_prompt, + return_tensors="pt", + ).input_ids.to(self.device) + + input_ids = torch.cat([audio_block, prompt_tensor], dim=1) + + with torch.no_grad(): + output_ids = self.model.mmu_generate( + input_ids, + max_new_tokens=int(max_new_tokens), + steps=int(steps), + block_length=int(block_length), + remasking=str(remasking), + ) + + decoded = self.uni_prompting.text_tokenizer.batch_decode( + output_ids[:, input_ids.shape[1]:], + skip_special_tokens=True, + )[0] + + return decoded.strip(), "Transcription generated successfully." + + # ------------------------------------------------------------------ + # Video-to-Text + # ------------------------------------------------------------------ + def run_v2t( + self, + video_path: Any, + steps: int, + block_length: int, + max_new_tokens: int, + ) -> Tuple[str, str]: + + resolved_path, converted = self._prepare_video_path(video_path) + if not resolved_path: + return "", "Please upload or record a video file first." + + try: + video_tokens = self._extract_video_tokens(resolved_path) + except Exception as exc: + return "", f"Failed to process video: {exc}" + spt = self.uni_prompting.sptids_dict + + question = random.choice(V2T_INSTRUCTION) + prompt = ( + "<|start_header_id|>user<|end_header_id|>\n" + f"{question}" + "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n" + ) + prompt_ids = self.uni_prompting.text_tokenizer( + prompt, + return_tensors="pt", + ).input_ids.to(self.device) + + input_ids = torch.cat( + [ + spt['<|v2t|>'].to(self.device).unsqueeze(0), + spt['<|soi|>'].to(self.device).unsqueeze(0), + video_tokens, + spt['<|eoi|>'].to(self.device).unsqueeze(0), + spt['<|sot|>'].to(self.device).unsqueeze(0), + prompt_ids, + ], + dim=1, + ).long() + + with torch.no_grad(): + output_ids = self.model.mmu_generate( + input_ids, + max_new_tokens=int(max_new_tokens), + steps=int(steps), + block_length=int(block_length), + ) + + decoded = self.uni_prompting.text_tokenizer.batch_decode( + output_ids[:, input_ids.shape[1]:], + skip_special_tokens=True, + )[0] + status_msg = "Video caption generated successfully." + if converted: + status_msg += " (Webcam recording converted to MP4.)" + return decoded.strip(), status_msg + + # ------------------------------------------------------------------ + # Helpers + # ------------------------------------------------------------------ + def _resolve_choice(self, choice: Optional[str], options): + if choice is None or choice == 'random': + return random.choice(options) + return choice + + def _prepare_video_path(self, video_input: Any) -> Tuple[Optional[str], bool]: + """Normalize Gradio video inputs (upload/webcam) to an MP4 filepath.""" + + candidate = None + if isinstance(video_input, str): + candidate = video_input + elif isinstance(video_input, dict): + candidate = ( + video_input.get("video") + or video_input.get("name") + or video_input.get("path") + ) + elif isinstance(video_input, (list, tuple)) and video_input: + candidate = str(video_input[0]) + + if not candidate: + return None, False + + candidate = str(candidate) + if not self._ensure_file_ready(candidate): + return None, False + + if candidate.lower().endswith(".mp4"): + return candidate, False + + converted = self._convert_to_mp4(candidate) + if converted: + return converted, True + suffix = Path(candidate).suffix or ".webm" + fd, tmp_path = tempfile.mkstemp(prefix="omada_v2t_raw_", suffix=suffix) + os.close(fd) + try: + shutil.copy2(candidate, tmp_path) + self._temp_video_files.append(tmp_path) + return tmp_path, False + except OSError: + try: + os.remove(tmp_path) + except OSError: + pass + return candidate, False + + if candidate.lower().endswith(".mp4"): + return candidate, False + + converted = self._convert_to_mp4(candidate) + if converted: + return converted, True + suffix = Path(candidate).suffix or ".webm" + fd, tmp_path = tempfile.mkstemp(prefix="omada_v2t_raw_", suffix=suffix) + os.close(fd) + try: + shutil.copy2(candidate, tmp_path) + self._temp_video_files.append(tmp_path) + return tmp_path, False + except OSError: + try: + os.remove(tmp_path) + except OSError: + pass + return candidate, False + + def _ensure_file_ready(self, path: str, retries: int = 8, delay: float = 0.2) -> bool: + """Ensure the uploaded/recorded file is fully written before processing.""" + + prev_size = -1 + for _ in range(retries): + try: + size = os.path.getsize(path) + except OSError: + size = -1 + if size <= 0: + time.sleep(delay) + continue + if size == prev_size: + return True + prev_size = size + time.sleep(delay) + return prev_size > 0 + + def _convert_to_mp4(self, src_path: str) -> Optional[str]: + """Convert arbitrary video file to MP4 using OpenCV (drops audio).""" + + cap = cv2.VideoCapture(src_path) + if not cap.isOpened(): + return None + + width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH) or 0) + height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT) or 0) + fps = cap.get(cv2.CAP_PROP_FPS) + + if width <= 0 or height <= 0: + cap.release() + return None + + if not fps or np.isnan(fps) or fps <= 0: + fps = 24.0 + + fd, tmp_path = tempfile.mkstemp(prefix="omada_v2t_", suffix=".mp4") + os.close(fd) + + writer = cv2.VideoWriter( + tmp_path, + cv2.VideoWriter_fourcc(*"mp4v"), + float(fps), + (width, height), + ) + if not writer.isOpened(): + cap.release() + try: + os.remove(tmp_path) + except OSError: + pass + return None + + frame_count = 0 + try: + while True: + ret, frame = cap.read() + if not ret: + break + writer.write(frame) + frame_count += 1 + finally: + cap.release() + writer.release() + + if frame_count == 0: + try: + os.remove(tmp_path) + except OSError: + pass + return None + + self._temp_video_files.append(tmp_path) + return tmp_path + + def _extract_video_tokens(self, video_path: str) -> torch.Tensor: + cap = cv2.VideoCapture(video_path) + total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + if total_frames <= 0: + cap.release() + raise RuntimeError(f"No readable frames in {video_path}") + + indices = np.linspace(0, total_frames - 1, 8, dtype=int) + frames = [] + for idx in range(total_frames): + ret, frame = cap.read() + if idx in indices and ret: + rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + pil = Image.fromarray(rgb) + frames.append(image_transform(pil, resolution=self.train_cfg.dataset.preprocessing.resolution)) + cap.release() + + if len(frames) == 0: + raise RuntimeError("Failed to sample frames for V2T inference.") + + video_tensor = torch.stack(frames).to(self.device) + video_tokens = self.vq_image.get_code(video_tensor) + self.text_vocab_size + return video_tokens.long().to(self.device).view(1, -1) + + +def build_demo(app: OmadaDemo, share: bool, server_name: str, server_port: Optional[int]): + theme = gr.themes.Soft(primary_hue="blue", neutral_hue="gray") + with gr.Blocks(title="OMada Stage1.3 Audio/Video Demo", css=CUSTOM_CSS, theme=theme) as demo: + with gr.Column(elem_classes=["omada-header"]): + if LOGO_DATA_URI: + gr.HTML( + f"
\"AIDAS
" + ) + elif LOGO_PATH.exists(): + gr.Image( + value=str(LOGO_PATH), + show_label=False, + height=140, + interactive=False, + elem_classes=["omada-logo"], + ) + gr.Markdown( + "## Omni-modal Diffusion Foundation Model\n### Pretrained Demo", + elem_classes=["omada-title"], + ) + gr.Markdown( + "Create speech, transcribe audio, and describe video from a single model. " + "Use the advanced sections when you want tighter control.", + elem_classes=["omada-tagline"], + ) + with gr.Row(elem_classes=["omada-layout"], equal_height=False): + with gr.Column(scale=3, min_width=480, elem_classes=["omada-chat-column"]): + chatbox = gr.Chatbot(label="Session", height=420, sanitize_html=False) + placeholder_map = { + "Text → Speech": "Type the speech you want to generate...", + "Speech → Text": "Upload audio on the right, then leave notes here if needed.", + "Video → Text": "Upload video on the right, then leave notes here if needed.", + } + chat_input = gr.Textbox( + label="Message", + placeholder=placeholder_map["Text → Speech"], + lines=3, + ) + with gr.Row(): + send_button = gr.Button("Send", variant="primary") + clear_button = gr.Button("Clear", variant="secondary") + with gr.Column(scale=2, min_width=360, elem_classes=["omada-controls"]): + mode_selector = gr.Dropdown( + ["Text → Speech", "Speech → Text", "Video → Text"], + value="Text → Speech", + label="Mode", + ) + with gr.Column(visible=True, elem_classes=["omada-mode-panel"]) as t2s_panel: + with gr.Group(elem_classes=["omada-card"]): + gr.Markdown("### Text-to-Speech Controls") + with gr.Group(elem_classes=["omada-advanced"]): + gr.Markdown("**Generation**") + with gr.Row(): + t2s_max_tokens = gr.Slider(2, 512, value=128, label="Speech token length", step=2) + t2s_steps = gr.Slider(2, 512, value=128, label="Total refinement steps", step=2) + with gr.Row(): + t2s_block = gr.Slider(2, 512, value=128, label="Block length", step=2) + t2s_cfg = gr.Slider(0.0, 6.0, value=3.5, label="CFG scale", step=0.1) + t2s_temperature = gr.Slider(0.0, 2.0, value=1.0, label="Sampling temperature", step=0.05) + with gr.Group(elem_classes=["omada-advanced"]): + gr.Markdown("**Voice styling**") + with gr.Row(): + t2s_gender = gr.Dropdown(['random'] + app.genders, value='random', label="Voice gender") + t2s_emotion = gr.Dropdown(['random'] + app.emotions, value='random', label="Emotion") + with gr.Row(): + t2s_speed = gr.Dropdown(['random'] + app.speeds, value='random', label="Speaking speed") + t2s_pitch = gr.Dropdown(['random'] + app.pitches, value='random', label="Pitch") + if T2S_EXAMPLES: + with gr.Group(elem_classes=["omada-card", "omada-examples-card"]): + gr.Markdown("**Sample prompts**") + with gr.Column(elem_classes=["omada-examples"]): + gr.Examples( + examples=T2S_EXAMPLES, + inputs=[chat_input], + examples_per_page=4, + ) + with gr.Column(visible=False, elem_classes=["omada-mode-panel"]) as s2t_panel: + with gr.Group(elem_classes=["omada-card"]): + gr.Markdown("### Speech-to-Text Controls") + s2t_audio = gr.Audio(type="filepath", label="Speech input", sources=["microphone", "upload"]) + with gr.Accordion("Generation settings", open=True, elem_classes=["omada-advanced"]): + with gr.Row(): + s2t_steps = gr.Slider(2, 512, value=128, label="Denoising steps", step=2) + s2t_block = gr.Slider(2, 512, value=128, label="Block length", step=2) + s2t_max_tokens = gr.Slider(2, 512, value=128, label="Max tokens", step=2) + s2t_remasking = gr.Dropdown( + choices=["low_confidence", "random"], + value="low_confidence", + label="Remasking strategy", + ) + if S2T_EXAMPLES: + with gr.Group(elem_classes=["omada-card", "omada-examples-card"]): + gr.Markdown("**Sample clips**") + with gr.Column(elem_classes=["omada-examples"]): + gr.Examples( + examples=S2T_EXAMPLES, + inputs=[s2t_audio], + examples_per_page=4, + ) + with gr.Column(visible=False, elem_classes=["omada-mode-panel"]) as v2t_panel: + with gr.Group(elem_classes=["omada-card"]): + gr.Markdown("### Video-to-Text Controls") + v2t_video = gr.Video( + label="Upload or record video", + format=None, + height=256, + sources=["upload", "webcam"], + ) + with gr.Accordion("Generation settings", open=True, elem_classes=["omada-advanced"]): + with gr.Row(): + v2t_steps = gr.Slider(2, 512, value=128, label="Denoising steps", step=2) + v2t_block = gr.Slider(2, 512, value=128, label="Block length", step=2) + v2t_max_tokens = gr.Slider(2, 512, value=128, label="Max tokens", step=2) + if V2T_EXAMPLES: + with gr.Group(elem_classes=["omada-card", "omada-examples-card"]): + gr.Markdown("**Sample videos**") + with gr.Column(elem_classes=["omada-examples"]): + gr.Examples( + examples=V2T_EXAMPLES, + inputs=[v2t_video], + examples_per_page=4, + ) + + def _toggle_controls(mode: str): + return ( + gr.update(visible=mode == "Text → Speech"), + gr.update(visible=mode == "Speech → Text"), + gr.update(visible=mode == "Video → Text"), + gr.update(placeholder=placeholder_map.get(mode, chat_input.placeholder)), + ) + + mode_selector.change( + _toggle_controls, + inputs=[mode_selector], + outputs=[t2s_panel, s2t_panel, v2t_panel, chat_input], + ) + + def _chat_handler( + history, + message, + mode, + audio_path, + video_path, + t2s_max_tokens, + t2s_steps, + t2s_block, + t2s_temperature, + t2s_cfg, + t2s_gender, + t2s_emotion, + t2s_speed, + t2s_pitch, + s2t_steps, + s2t_block, + s2t_max_tokens, + s2t_remasking, + v2t_steps, + v2t_block, + v2t_max_tokens, + ): + history = history or [] + message = (message or "").strip() + response = "" + + if mode == "Text → Speech": + if not message: + status = "Please type some text for speech synthesis." + response = _render_text_message(status, "") + else: + audio_result, status = app.run_t2s( + message, + t2s_max_tokens, + t2s_steps, + t2s_block, + t2s_temperature, + t2s_cfg, + t2s_gender, + t2s_emotion, + t2s_speed, + t2s_pitch, + ) + response = _render_audio_message(status, audio_result) + display_user_raw = message or "[Speech request]" + elif mode == "Speech → Text": + if not audio_path: + status = "Please upload or record an audio clip first." + response = _render_text_message(status, "") + else: + transcript, status = app.run_s2t( + audio_path, + s2t_steps, + s2t_block, + s2t_max_tokens, + s2t_remasking, + ) + response = _render_text_message(status, transcript) + display_user_raw = message or "[Audio transcription request]" + else: # Video → Text + if not video_path: + status = "Please upload or record a video first." + response = _render_text_message(status, "") + else: + caption, status = app.run_v2t( + video_path, + v2t_steps, + v2t_block, + v2t_max_tokens, + ) + response = _render_text_message(status, caption) + display_user_raw = message or "[Video caption request]" + + display_user = _format_user_message(display_user_raw) + history = history + [(display_user, response)] + return history, "" + + submit_inputs = [ + chatbox, + chat_input, + mode_selector, + s2t_audio, + v2t_video, + t2s_max_tokens, + t2s_steps, + t2s_block, + t2s_temperature, + t2s_cfg, + t2s_gender, + t2s_emotion, + t2s_speed, + t2s_pitch, + s2t_steps, + s2t_block, + s2t_max_tokens, + s2t_remasking, + v2t_steps, + v2t_block, + v2t_max_tokens, + ] + submit_outputs = [chatbox, chat_input] + + chat_input.submit(_chat_handler, inputs=submit_inputs, outputs=submit_outputs) + send_button.click(_chat_handler, inputs=submit_inputs, outputs=submit_outputs) + + def _clear_session(): + return ( + [], + "", + gr.update(value=None), + gr.update(value=None), + ) + + clear_button.click( + _clear_session, + inputs=None, + outputs=[chatbox, chat_input, s2t_audio, v2t_video], + ) + + demo.launch( + share=share, + server_name=server_name, + server_port=server_port, + ) + + +def parse_args(): + parser = argparse.ArgumentParser(description="OMada Gradio demo for audio/video tasks") + parser.add_argument("--train-config", required=True, help="Path to the training YAML used to build tokenizer + VQ modules") + parser.add_argument("--checkpoint", required=True, help="Path to an `unwrapped_model` directory") + parser.add_argument("--device", default=None, help="Override device (e.g. cuda:0). Defaults to CUDA if available") + parser.add_argument("--share", action="store_true", help="Enable public Gradio share link") + parser.add_argument("--server-name", default="0.0.0.0", help="Host address for Blocks.launch") + parser.add_argument("--server-port", type=int, default=None, help="Port for Blocks.launch") + return parser.parse_args() + + +def main(): + args = parse_args() + app = OmadaDemo(args.train_config, args.checkpoint, args.device) + build_demo(app, args.share, args.server_name, args.server_port) + + +if __name__ == "__main__": + main() diff --git a/MMaDA/inference/gradio_multimodal_demo_inst.py b/MMaDA/inference/gradio_multimodal_demo_inst.py new file mode 100644 index 0000000000000000000000000000000000000000..b680d619caf0d986578bdc6f011d6914c2487a34 --- /dev/null +++ b/MMaDA/inference/gradio_multimodal_demo_inst.py @@ -0,0 +1,2469 @@ +#!/usr/bin/env python3 +""" +Gradio demo for OMada Stage 1.3 checkpoints covering: + * Text-to-Speech (T2S) + * Speech-to-Text (S2T) + * Video-to-Text (V2T) + +The implementation wraps the existing CLI inference helpers so that a single +checkpoint directory (…/checkpoint-XXXX/unwrapped_model) can be previewed +interactively. Usage: + + python MMaDA/inference/gradio_multimodal_demo_inst.py --train-config /t1data/users/snu-lab-d/omada/OMaDA/MMaDA/inference/demo/demo.yaml --checkpoint ../ckpt/checkpoint-400000/unwrapped_model/ --server-port 7860 + +If you need remote access, pass `--share` which simply forwards the flag to +`gradio.Blocks.launch`. For more reliable sharing, run the demo locally without +`--share` and tunnel the chosen port with a tool such as `ngrok http 7860` or +`cloudflared tunnel --url http://localhost:7860` instead of relying on Gradio’s +temporary share links. +""" + +import argparse +import base64 +import html +import io +import os +import math +import random +import sys +import tempfile +import wave +from pathlib import Path +import shutil +import time +from typing import Any, List, Optional, Tuple +import numpy as np +import torch.nn.functional as F +from PIL import Image + +CUSTOM_CSS = """ +:root { + --omada-primary: #1e3a8a; + --omada-accent: #1d4ed8; + --omada-surface: #f3f4f6; + --omada-surface-alt: #ffffff; + --omada-border: #d0d7e5; + --omada-text-primary: #111827; + --omada-text-muted: #374151; + color-scheme: light; +} +html, body, body.dark, html.dark { + background: var(--omada-surface) !important; + color: var(--omada-text-primary) !important; +} +.gradio-container { + background: var(--omada-surface); + color: var(--omada-text-primary); +} +.omada-page-heading { + margin-bottom: 0; + color: var(--omada-text-primary); +} +.omada-tab-intro p { + font-size: 0.95rem; + color: var(--omada-text-primary); + margin-top: 0; + opacity: 0.9; +} +.omada-card { + background: var(--omada-surface-alt); + border-radius: 16px !important; + padding: 18px !important; + box-shadow: none; + border: 1px solid var(--omada-border); + color: var(--omada-text-primary); +} +.omada-card .gradio-slider .wrap-inner { + gap: 6px; +} +.omada-card .gradio-slider input[type=range]::-webkit-slider-thumb { + background: var(--omada-primary); +} +.gradio-slider input[type=range]::-webkit-slider-runnable-track { + background: rgba(14, 33, 80, 0.2); +} +.omada-section-title p { + text-transform: uppercase; + font-size: 0.78rem; + letter-spacing: 0.14em; + color: rgba(30, 58, 138, 0.85); + margin: 0 0 12px 0; +} +.omada-output .gradio-audio, .omada-output .gradio-textbox { + margin-top: 12px; +} +.gradio-textbox, .gradio-dropdown, .gradio-slider, .gradio-audio { + color: var(--omada-text-primary) !important; +} +.gradio-dropdown .wrap .label, .gradio-textbox label, .gradio-slider label { + color: var(--omada-text-primary); +} +.gradio-dropdown .single-select, .gradio-textbox textarea { + background: #ffffff !important; + border: 1px solid var(--omada-border) !important; + color: var(--omada-text-primary) !important; +} +.gradio-dropdown .single-select select, .gradio-textbox textarea { + color: var(--omada-text-primary) !important; +} +.gradio-textbox textarea::placeholder { + color: rgba(148, 163, 184, 0.65); +} +.gradio-dropdown, .gradio-textbox, .gradio-audio, .gradio-video, .gradio-slider { + background: #ffffff !important; + border: 1px solid var(--omada-border) !important; + border-radius: 12px !important; +} +.full-width-button button { + width: 100%; + background: var(--omada-primary) !important; + color: white !important; + border: none !important; + font-weight: 600; + transition: transform 0.2s ease, box-shadow 0.2s ease; + box-shadow: 0 12px 30px -12px rgba(79, 70, 229, 0.65); +} +.full-width-button button:hover { + transform: translateY(-1px); + box-shadow: 0 18px 34px -14px rgba(79, 70, 229, 0.75); +} +.omada-advanced .gr-accordion-header { + font-size: 0.85rem; + letter-spacing: 0.05em; + color: var(--omada-text-muted); +} +.omada-advanced .gr-accordion { + border: 1px solid var(--omada-border); + border-radius: 12px; + background: #ffffff; +} +.gradio-tabs { + background: transparent; +} +.gradio-tabs ul.tab-list { + background: transparent; + border-bottom: 1px solid var(--omada-border); +} +.gradio-tabs button { + color: var(--omada-text-primary); +} +.gradio-tabs button.selected { + color: var(--omada-text-primary); + background: rgba(14, 33, 80, 0.1); + border-bottom: 2px solid var(--omada-primary); +} +.gradio-container .label { + background: rgba(30, 58, 138, 0.1) !important; + color: var(--omada-primary) !important; + border: 1px solid rgba(30, 58, 138, 0.25) !important; + border-radius: 999px !important; + padding: 4px 12px !important; +} +.gradio-button.primary { + background: var(--omada-primary) !important; + color: #ffffff !important; + border: 1px solid var(--omada-primary) !important; +} +.gradio-accordion { + box-shadow: none; +} +.omada-layout { + gap: 20px !important; +} +.omada-chat-column { + gap: 12px !important; +} +.omada-chat-column .gradio-chatbot { + border-radius: 16px; + box-shadow: none; + border: 1px solid var(--omada-border); + background: #ffffff; +} +.omada-controls { + gap: 16px !important; +} +.omada-mode-panel { + display: flex; + flex-direction: column; + gap: 16px !important; +} +.omada-examples-card { + padding-top: 10px !important; +} +.omada-output-panel .gradio-audio, +.omada-output-panel .gradio-textbox { + margin-top: 8px; +} +.omada-response-container { + display: flex; + flex-direction: column; + gap: 10px; +} +.omada-response-status { + margin: 0; + font-weight: 600; + font-size: 0.95rem; + color: var(--omada-text-primary); +} +.omada-response-block, +.omada-audio-block { + background: rgba(30, 58, 138, 0.05); + border-radius: 12px; + padding: 12px 14px; + color: var(--omada-text-primary); + white-space: pre-wrap; + word-break: break-word; +} +.omada-audio-block audio { + width: 100%; +} +.omada-header { + display: flex; + flex-direction: column; + align-items: center; + justify-content: center; + gap: 18px !important; + margin-bottom: 18px; + text-align: center; +} +.omada-header .gradio-image { + background: transparent !important; + border: none !important; +} +.omada-header img { + object-fit: contain; + display: block; +} +.omada-logo { + max-width: 180px; + padding: 0 !important; +} +.omada-examples { + margin-top: 8px; + padding-top: 4px; +} +.omada-examples .gradio-dataset { + width: 100% !important; +} +.omada-examples .samples-table { + width: 100% !important; +} +.omada-examples table { + width: 100% !important; +} +.omada-examples td { + width: 100% !important; +} +.omada-examples .sample-button, +.omada-examples button { + width: 100% !important; + white-space: pre-wrap !important; + word-wrap: break-word !important; + height: auto !important; + min-height: 40px !important; + text-align: left !important; + padding: 12px 16px !important; + line-height: 1.5 !important; + overflow: visible !important; + text-overflow: clip !important; + display: block !important; +} +.omada-examples button span { + white-space: pre-wrap !important; + word-wrap: break-word !important; + overflow: visible !important; + text-overflow: clip !important; + display: block !important; + width: 100% !important; +} +.omada-logo .gradio-image, +.omada-logo .gradio-image > div, + .omada-logo .gradio-image .container { + background: transparent !important; + border: none !important; + box-shadow: none !important; + padding: 0 !important; + display: flex; + justify-content: center; + align-items: center; + } + .omada-logo img { + width: 100%; + height: auto; + } +.omada-logo button { + display: none !important; +} +.gradio-container .gradio-component, +.gradio-container .gradio-panel, +.gradio-container .gradio-box { + background: transparent !important; + color: var(--omada-text-primary); +} +.dark .gradio-container, +.dark .gradio-interface, +.dark .gradio-container * { + background-color: inherit !important; + color: var(--omada-text-primary) !important; +} +.dark .gradio-container .gradio-chatbot, +.dark .gradio-container .gradio-dropdown, +.dark .gradio-container .gradio-textbox, +.dark .gradio-container .gradio-audio, +.dark .gradio-container .gradio-video, +.dark .gradio-container .gradio-slider, +.dark .gradio-container .gradio-accordion, +.dark .gradio-container .gradio-panel, +.dark .gradio-container .gradio-box { + background: #ffffff !important; + border-color: var(--omada-border) !important; +} +.omada-title h2 { + font-size: 2.4rem; + font-weight: 700; + color: var(--omada-text-primary); + margin: 0; +} +.omada-title h3 { + font-size: 1.25rem; + font-weight: 600; + letter-spacing: 0.1em; + text-transform: uppercase; + color: var(--omada-text-muted); + margin: 6px 0 0; +} +.omada-tagline p { + color: var(--omada-text-primary); + font-size: 1rem; + margin: 0; + opacity: 0.9; + line-height: 1.45; +} +.omada-tagline .tagline-speech { color: #f97316; font-weight: 600; } +.omada-tagline .tagline-audio { color: #ec4899; font-weight: 600; } +.omada-tagline .tagline-video { color: #0ea5e9; font-weight: 600; } +.omada-tagline .tagline-text { color: #a855f7; font-weight: 600; } +.omada-tagline .tagline-image { color: #22c55e; font-weight: 600; } +.gradio-container .prose :where(h1, h2, h3, h4, h5, h6) { + color: var(--omada-text-primary) !important; +} +.gradio-container .prose :where(p, li) { + color: var(--omada-text-muted) !important; +} +.gradio-container label, .gradio-container span, .gradio-container button { + color: var(--omada-text-primary); +} +.gradio-container .dark { + background: #ffffff !important; + color: var(--omada-text-primary) !important; +} +.omada-logo-img { + max-width: 250px; + width: 100%; + height: auto; + display: block; + margin: 0 auto; +} +.omada-logo-wrapper { + display: flex; + justify-content: center; + align-items: center; +} +""" + +FORCE_LIGHT_MODE_JS = """ +function() { + document.body.classList.remove('dark'); + document.documentElement.classList.remove('dark'); + const observer = new MutationObserver(function(mutations) { + document.body.classList.remove('dark'); + document.documentElement.classList.remove('dark'); + }); + observer.observe(document.body, { attributes: true, attributeFilter: ['class'] }); + observer.observe(document.documentElement, { attributes: true, attributeFilter: ['class'] }); +} +""" + +DEMO_ROOT = Path(__file__).resolve().parent / "demo" +LOGO_PATH = DEMO_ROOT / "logo.png" +T2S_TEXT_PATH = DEMO_ROOT / "t2s" / "text.txt" +CHAT_TEXT_PATH = DEMO_ROOT / "chat" / "text.txt" +T2I_TEXT_PATH = DEMO_ROOT / "t2i" / "text.txt" + + +def _load_logo_data() -> Optional[str]: + if not LOGO_PATH.exists(): + return None + try: + import base64 + except ImportError: + return str(LOGO_PATH) + try: + encoded = base64.b64encode(LOGO_PATH.read_bytes()).decode("utf-8") + except OSError: + return str(LOGO_PATH) + return f"data:image/png;base64,{encoded}" + + +def _load_t2s_examples(): + if not T2S_TEXT_PATH.exists(): + return [] + lines = [ + line.strip() + for line in T2S_TEXT_PATH.read_text(encoding="utf-8").splitlines() + if line.strip() + ] + return [[line] for line in lines] + + +def _load_chat_examples(): + if not CHAT_TEXT_PATH.exists(): + return [] + lines = [ + line.strip() + for line in CHAT_TEXT_PATH.read_text(encoding="utf-8").splitlines() + if line.strip() + ] + return [[line] for line in lines] + + +def _load_t2i_examples(): + if not T2I_TEXT_PATH.exists(): + return [] + lines = [ + line.strip() + for line in T2I_TEXT_PATH.read_text(encoding="utf-8").splitlines() + if line.strip() + ] + return [[line] for line in lines] + + +def _load_media_examples(subdir: str, suffixes): + target_dir = DEMO_ROOT / subdir + if not target_dir.exists(): + return [] + examples = [] + for path in sorted(target_dir.iterdir()): + if path.is_file() and path.suffix.lower() in suffixes: + examples.append([str(path)]) + return examples + + +T2S_EXAMPLES = _load_t2s_examples() +CHAT_EXAMPLES = _load_chat_examples() +T2I_EXAMPLES = _load_t2i_examples() +S2T_EXAMPLES = _load_media_examples("s2t", {".wav", ".mp3", ".flac", ".ogg"}) +V2T_EXAMPLES = _load_media_examples("v2t", {".mp4", ".mov", ".avi", ".webm"}) +S2S_EXAMPLES = _load_media_examples("s2s", {".wav", ".mp3", ".flac", ".ogg"}) +if not S2S_EXAMPLES: + S2S_EXAMPLES = S2T_EXAMPLES[: min(4, len(S2T_EXAMPLES))] +V2S_EXAMPLES = _load_media_examples("v2s", {".mp4", ".mov", ".avi", ".webm"}) +if not V2S_EXAMPLES: + V2S_EXAMPLES = V2T_EXAMPLES[: min(4, len(V2T_EXAMPLES))] +I2S_EXAMPLES = _load_media_examples("i2s", {".png", ".jpg", ".jpeg", ".webp"}) +LOGO_DATA_URI = _load_logo_data() + +MMU_IMAGE_A = DEMO_ROOT / "mmu" / "1.jpg" +MMU_IMAGE_B = DEMO_ROOT / "mmu" / "2.jpg" +# MMU_IMAGE_C = DEMO_ROOT / "mmu" / "SD_IMG_00235_1.png" +# MMU_IMAGE_D = DEMO_ROOT / "mmu" / "SD_IMG_00235_2.png" +if MMU_IMAGE_A.exists() and MMU_IMAGE_B.exists(): + MMU_EXAMPLES = [ + # [ + # str(MMU_IMAGE_C), + # str(MMU_IMAGE_D), + # "What are the differences between the two images?" + # ], + [ + str(MMU_IMAGE_A), + str(MMU_IMAGE_B), + "What are the differences in coloring and physical features between animal1 and animal2 in the bird images?", + ] + ] +else: + MMU_EXAMPLES = [] +if not I2S_EXAMPLES and MMU_EXAMPLES: + I2S_EXAMPLES = [[example[0]] for example in MMU_EXAMPLES] + + +def _render_response(status: str, body_html: str = "") -> str: + safe_status = html.escape(status or "") + parts = [] + if safe_status: + parts.append(f"

{safe_status}

") + if body_html: + parts.append(body_html) + content = "".join(parts) + return f"
{content}
" + + +def _render_text_message(status: str, content: Optional[str]) -> str: + # post-processing + if content: + remove_tokens = ["<|eot_id|>", "<|eot_id>", "", "", "", "assistant"] + for token in remove_tokens: + content = content.replace(token, "") + + content = (content or "").strip() + if not content: + return _render_response(status) + safe_content = html.escape(content).replace("\n", "
") + body = f"
{safe_content}
" + return _render_response(status, body) + + +def _render_audio_message(status: str, audio: Optional[Tuple[int, np.ndarray]]) -> str: + """Render an inline HTML audio player for chat responses.""" + + if not audio: + return _render_response(status) + + sample_rate, data = audio + if data is None: + return _render_response(status) + + waveform = np.asarray(data, dtype=np.float32) + if waveform.size == 0: + return _render_response(status) + + if waveform.ndim == 1: + waveform = waveform[:, None] + + channels = waveform.shape[1] + clipped = np.clip(waveform, -1.0, 1.0) + pcm16 = (clipped * 32767.0).astype(np.int16) + + buffer = io.BytesIO() + with wave.open(buffer, "wb") as wav_writer: + wav_writer.setnchannels(channels) + wav_writer.setsampwidth(2) # 16-bit PCM + wav_writer.setframerate(int(sample_rate)) + wav_writer.writeframes(pcm16.tobytes()) + + encoded = base64.b64encode(buffer.getvalue()).decode("ascii") + audio_tag = ( + "
" + "" + "
" + ) + return _render_response(status, audio_tag) + + +def _render_image_message(status: str, image: Optional[Image.Image]) -> str: + if image is None: + return _render_response(status) + + buffer = io.BytesIO() + try: + image.save(buffer, format="PNG") + except Exception: + return _render_response(status) + + encoded = base64.b64encode(buffer.getvalue()).decode("ascii") + image_html = ( + "
" + "Generated image" + "
" + ) + return _render_response(status, image_html) + + +def _format_user_message(message: str) -> str: + clean = html.escape(message or "") + return clean.replace("\n", "
") + +# Ensure project modules (models, training, inference.common, …) are importable when the +# script is launched directly via `python MMaDA/inference/gradio_multimodal_demo.py`. +PROJECT_ROOT = Path(__file__).resolve().parents[1] +if str(PROJECT_ROOT) not in sys.path: + sys.path.insert(0, str(PROJECT_ROOT)) + +import cv2 +import gradio as gr +import numpy as np +import torch +from PIL import Image + +from inference.common import ( + build_uni_prompting, + get_vq_model_audio, + get_vq_model_image, + load_omada_from_checkpoint, + load_train_config, +) +from models import get_mask_schedule +from models.modeling_omada import add_gumbel_noise, get_num_transfer_tokens +from training.data import S2T_INSTRUCTION, T2S_INSTRUCTION, V2T_INSTRUCTION, V2S_INSTRUCTION +from training.utils import image_transform + + +def _resolve_noise_schedule(train_cfg) -> callable: + """Return the diffusion noise schedule used for T2S sampling.""" + + schedule_cfg = getattr(train_cfg, "mask_schedule", None) + if schedule_cfg and hasattr(schedule_cfg, "schedule"): + schedule_name = schedule_cfg.schedule + schedule_kwargs = schedule_cfg.get("params", {}) + return get_mask_schedule(schedule_name, **schedule_kwargs) + + schedule_name = train_cfg.training.get("mask_schedule", "cosine") + return get_mask_schedule(schedule_name) + + +class OmadaDemo: + """Lightweight container that loads all inference assets once.""" + + def __init__(self, train_config: str, checkpoint: str, device: Optional[str] = None): + ckpt_path = Path(checkpoint) + if ckpt_path.name != "unwrapped_model": + raise ValueError( + "`--checkpoint` must point to an `unwrapped_model` directory. " + f"Received: {checkpoint}" + ) + + self.device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu")) + self.train_cfg = load_train_config(train_config) + self.uni_prompting, _ = build_uni_prompting(self.train_cfg) + + # Core models + self.model = load_omada_from_checkpoint(str(ckpt_path), self.device) + self.vq_audio = get_vq_model_audio(self.train_cfg, self.device) + self.vq_image = get_vq_model_image(self.train_cfg, self.device) + + self.model.eval() + self.vq_audio.eval() + self.vq_image.eval() + + # Cached constants + self.mask_token_id = int(self.model.config.mask_token_id) + self.noise_schedule = _resolve_noise_schedule(self.train_cfg) + self.sample_rate = int(getattr(self.vq_audio.u2s_config.data, "sampling_rate", 22050)) + + self.genders = ['female', 'male'] + self.emotions = ['angry', 'happy', 'neutral', 'sad'] + self.speeds = ['normal', 'fast', 'slow'] + self.pitches = ['normal', 'high', 'low'] + + # Pre-computed offsets reused across calls + self.text_vocab_size = len(self.uni_prompting.text_tokenizer) + # The current checkpoints assume a fixed 8k vision codebook and 4k speech codebook. + self.codebook_size = 8192 + self.speech_codebook = self.codebook_size + self.audio_codebook_size = 4096 + + self.max_audio_len_short = int( + getattr( + self.uni_prompting, + "max_audio_len_short", + getattr(self.train_cfg.dataset.preprocessing, "max_aud_length_short", 256), + ) + ) + self.max_text_len = int(getattr(self.train_cfg.dataset.preprocessing, "max_seq_length", 1024)) + + image_seq_len = getattr(self.model.config, "num_vq_tokens", None) + if image_seq_len is None: + image_seq_len = getattr(getattr(self.train_cfg.model, "omada", None), "num_vq_tokens", None) + if image_seq_len is None: + image_seq_len = getattr(getattr(self.train_cfg.model, "vq_model_image", None), "num_vq_tokens", None) + if image_seq_len is None: + image_seq_len = 1024 + self.image_seq_len = int(image_seq_len) + print(self.image_seq_len) + self.image_resolution = int(getattr(self.train_cfg.dataset.preprocessing, "resolution", 256)) + self.image_noise_schedule = _resolve_noise_schedule(self.train_cfg) + + self.audio_condition_default = "gender-female_emotion-neutral_speed-normal_pitch-normal" + style_map = getattr(getattr(self.vq_audio, "config", None), "u2s_style2idx", None) + if isinstance(style_map, dict): + self._valid_conditions = set(style_map.keys()) + if self._valid_conditions and self.audio_condition_default not in self._valid_conditions: + # Ensure the default condition is valid for the tokenizer. + self.audio_condition_default = next(iter(self._valid_conditions)) + else: + self._valid_conditions = set() + self._temp_video_files = [] + + # ------------------------------------------------------------------ + # Text-to-Speech + # ------------------------------------------------------------------ + def run_t2s( + self, + text: str, + max_new_tokens: int, + steps: int, + block_length: int, + temperature: float, + cfg_scale: float, + gender_choice: str, + emotion_choice: str, + speed_choice: str, + pitch_choice: str, + ) -> Tuple[Optional[Tuple[int, np.ndarray]], str]: + + if text is None or not text.strip(): + return None, "Please provide text to synthesize." + + speech_len, steps, block_length = self._prepare_block_schedule( + max_new_tokens, + steps, + block_length, + ) + + gender = self._resolve_choice(gender_choice, self.genders) + emotion = self._resolve_choice(emotion_choice, self.emotions) + speed = self._resolve_choice(speed_choice, self.speeds) + pitch = self._resolve_choice(pitch_choice, self.pitches) + + text = text.strip().upper() + prompt = ( + "<|start_header_id|>user<|end_header_id|>\n" + f"{random.choice(T2S_INSTRUCTION)}\n{text}" + "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n" + ) + + audio_tokens = torch.full( + (1, speech_len), + fill_value=self.mask_token_id, + dtype=torch.long, + device=self.device, + ) + + input_ids, attention_mask = self.uni_prompting(([prompt], audio_tokens), "t2s_gen") + input_ids = input_ids.to(self.device) + attention_mask = attention_mask.to(self.device) + + with torch.no_grad(): + outputs = self.model.t2s_generate_mmu_like( + input_ids=input_ids, + max_new_tokens=int(speech_len), + steps=int(steps), + block_length=int(block_length), + temperature=float(temperature), + cfg_scale=float(cfg_scale), + mask_token_id=self.mask_token_id, + attention_mask=attention_mask, + uni_prompting=self.uni_prompting, + codebook_size=self.codebook_size, + ) + + if not outputs: + return None, "Generation produced no speech tokens." + + rel = outputs[0] + if isinstance(rel, torch.Tensor): + rel_ids = rel.detach().cpu().tolist() + else: + rel_ids = list(rel) + + if not rel_ids: + return None, "Generation produced no speech tokens." + + speech_units = "".join(f"<|speech_{sid}|>" for sid in rel_ids) + condition = f"gender-{gender}_emotion-{emotion}_speed-{speed}_pitch-{pitch}" + wav = self.vq_audio.decode( + speech_units, + condition=condition, + output_wav_file=os.path.join("/tmp", "omada_t2s.wav"), + ) + + audio = (self.sample_rate, wav.astype(np.float32)) + status = f"Speech generated! ({gender}/{emotion}/{speed}/{pitch})." + return audio, status + + # ------------------------------------------------------------------ + # Speech-to-Speech + # ------------------------------------------------------------------ + def run_s2s( + self, + audio_path: Optional[str], + max_new_tokens: int, + steps: int, + block_length: int, + temperature: float, + cfg_scale: float, + ) -> Tuple[Optional[Tuple[int, np.ndarray]], str]: + if not audio_path: + return None, "Please upload source speech first." + + try: + user_tokens = self.vq_audio.encode(audio_path) + except Exception as exc: + return None, f"Failed to encode input audio: {exc}" + + if not isinstance(user_tokens, torch.Tensor): + user_tokens = torch.tensor(user_tokens) + if user_tokens.dim() == 1: + user_tokens = user_tokens.unsqueeze(0) + + user_tokens = user_tokens.to(self.device, dtype=torch.long) + if user_tokens.numel() == 0: + return None, "Uploaded speech clip produced no tokens." + + gen_len = max(1, int(max_new_tokens)) + gen_len = min(gen_len, self.max_audio_len_short) + gen_len, steps, block_length = self._prepare_block_schedule( + gen_len, + steps, + block_length, + ) + offset = self.text_vocab_size + self.codebook_size + user_shifted = user_tokens + offset + assistant_placeholder = torch.full( + (1, gen_len), + self.mask_token_id, + dtype=torch.long, + device=self.device, + ) + + input_ids, attention_mask = self.uni_prompting( + ([user_shifted], [assistant_placeholder]), + "s2s_gen", + ) + input_ids = input_ids.to(self.device) + attention_mask = attention_mask.to(self.device) + + with torch.no_grad(): + outputs = self.model.t2s_generate_mmu_like( + input_ids=input_ids, + max_new_tokens=gen_len, + steps=int(steps), + block_length=int(block_length), + temperature=float(temperature), + cfg_scale=float(cfg_scale), + mask_token_id=self.mask_token_id, + attention_mask=attention_mask, + uni_prompting=self.uni_prompting, + codebook_size=self.codebook_size, + audio_codebook_size=self.audio_codebook_size, + ) + + if not outputs: + return None, "Generation returned no tokens." + + generated = outputs[0] + if isinstance(generated, torch.Tensor): + generated = generated.detach().cpu() + + eoa_token_id = int(self.uni_prompting.sptids_dict['<|eoa|>'][0].item()) + mask_token_id = int(self.mask_token_id) + token_list = [] + for tok in generated.tolist(): + tok = int(tok) + if tok < 0: + continue + if tok == eoa_token_id: + break + if tok == mask_token_id: + continue + if tok >= self.audio_codebook_size: + continue + token_list.append(tok) + if not token_list: + return None, "Generated sequence was empty after post-processing." + + speech_units = "".join(f"<|speech_{tok}|>" for tok in token_list) + condition = self._resolve_condition(self.audio_condition_default) + fd, temp_path = tempfile.mkstemp(prefix="omada_s2s_", suffix=".wav") + os.close(fd) + try: + wav = self.vq_audio.decode( + speech_units, + condition=condition, + output_wav_file=temp_path, + ) + finally: + try: + os.remove(temp_path) + except OSError: + pass + + audio = (self.sample_rate, wav.astype(np.float32)) + return audio, f"Speech response generated successfully. (voice: {condition})" + + # ------------------------------------------------------------------ + # Speech-to-Text + # ------------------------------------------------------------------ + def run_s2t( + self, + audio_path: Optional[str], + steps: int, + block_length: int, + max_new_tokens: int, + remasking: str, + ) -> Tuple[str, str]: + + if not audio_path: + return "", "Please upload an audio file first." + + tokens = self.vq_audio.encode(audio_path).to(self.device) + offset = self.text_vocab_size + self.speech_codebook + tokens = tokens + offset + + spt = self.uni_prompting.sptids_dict + audio_block = torch.cat( + [ + spt['<|s2t|>'].to(self.device).unsqueeze(0), + spt['<|soa|>'].to(self.device).unsqueeze(0), + tokens.to(self.device), + spt['<|eoa|>'].to(self.device).unsqueeze(0), + ], + dim=1, + ) + + prompt_text = random.choice(S2T_INSTRUCTION) + chat_prompt = ( + "<|start_header_id|>user<|end_header_id|>\n" + f"{prompt_text}" + "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n" + ) + prompt_tensor = self.uni_prompting.text_tokenizer( + chat_prompt, + return_tensors="pt", + ).input_ids.to(self.device) + + input_ids = torch.cat([audio_block, prompt_tensor], dim=1) + + with torch.no_grad(): + output_ids = self.model.mmu_generate( + input_ids, + max_new_tokens=int(max_new_tokens), + steps=int(steps), + block_length=int(block_length), + remasking=str(remasking), + ) + + decoded = self.uni_prompting.text_tokenizer.batch_decode( + output_ids[:, input_ids.shape[1]:], + skip_special_tokens=True, + )[0] + + return decoded.strip(), "Transcription generated successfully." + + # ------------------------------------------------------------------ + # Video-to-Text + # ------------------------------------------------------------------ + def run_v2t( + self, + video_path: Any, + steps: int, + block_length: int, + max_new_tokens: int, + ) -> Tuple[str, str]: + + resolved_path, converted = self._prepare_video_path(video_path) + if not resolved_path: + return "", "Please upload or record a video file first." + + try: + video_tokens = self._extract_video_tokens(resolved_path) + except Exception as exc: + return "", f"Failed to process video: {exc}" + spt = self.uni_prompting.sptids_dict + + question = random.choice(V2T_INSTRUCTION) + prompt = ( + "<|start_header_id|>user<|end_header_id|>\n" + f"{question}" + "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n" + ) + prompt_ids = self.uni_prompting.text_tokenizer( + prompt, + return_tensors="pt", + ).input_ids.to(self.device) + + input_ids = torch.cat( + [ + spt['<|v2t|>'].to(self.device).unsqueeze(0), + spt['<|soi|>'].to(self.device).unsqueeze(0), + video_tokens, + spt['<|eoi|>'].to(self.device).unsqueeze(0), + spt['<|sot|>'].to(self.device).unsqueeze(0), + prompt_ids, + ], + dim=1, + ).long() + + with torch.no_grad(): + output_ids = self.model.mmu_generate( + input_ids, + max_new_tokens=int(max_new_tokens), + steps=int(steps), + block_length=int(block_length), + ) + + decoded = self.uni_prompting.text_tokenizer.batch_decode( + output_ids[:, input_ids.shape[1]:], + skip_special_tokens=True, + )[0] + status_msg = "Video caption generated successfully." + if converted: + status_msg += " (Webcam recording converted to MP4.)" + return decoded.strip(), status_msg + + # ------------------------------------------------------------------ + # Text-to-Image + # ------------------------------------------------------------------ + def run_t2i( + self, + prompt: str, + timesteps: int, + temperature: float, + guidance_scale: float, + ) -> Tuple[Optional[Image.Image], str]: + if not prompt or not prompt.strip(): + return None, "Please provide a text prompt." + + image_tokens = torch.full( + (1, self.image_seq_len), + self.mask_token_id, + dtype=torch.long, + device=self.device, + ) + input_ids, attention_mask = self.uni_prompting(([prompt.strip()], image_tokens), "t2i_gen") + input_ids = input_ids.to(self.device) + attention_mask = attention_mask.to(self.device) + + if guidance_scale > 0: + uncond_ids, uncond_mask = self.uni_prompting(([""], image_tokens.clone()), "t2i_gen") + uncond_ids = uncond_ids.to(self.device) + uncond_mask = uncond_mask.to(self.device) + else: + uncond_ids = None + uncond_mask = None + + with torch.no_grad(): + gen_tokens = self.model.t2i_generate( + input_ids=input_ids, + uncond_input_ids=uncond_ids, + attention_mask=attention_mask, + uncond_attention_mask=uncond_mask, + guidance_scale=float(guidance_scale), + temperature=float(temperature), + timesteps=int(timesteps), + noise_schedule=self.image_noise_schedule, + seq_len=self.image_seq_len, + mask_token_id=self.mask_token_id, + codebook_size=self.codebook_size, + uni_prompting=self.uni_prompting, + config=self.train_cfg, + ) + + if gen_tokens is None: + return None, "Image generation failed." + + gen_tokens = torch.clamp(gen_tokens, min=0, max=self.codebook_size - 1) + image = self._decode_image_tokens(gen_tokens[0]) + return image, "Image generated from text prompt." + + # ------------------------------------------------------------------ + # Image-to-Image Editing + # ------------------------------------------------------------------ + def run_i2i( + self, + instruction: str, + source_image: Optional[Image.Image], + timesteps: int, + temperature: float, + guidance_scale: float, + ) -> Tuple[Optional[Image.Image], str]: + if source_image is None: + return None, "Please upload a reference image." + if not instruction or not instruction.strip(): + return None, "Provide editing instructions for the image." + + try: + input_tokens = self._prepare_image_tokens(source_image) + except Exception as exc: + return None, f"Failed to encode input image: {exc}" + + output_placeholder = torch.full( + (1, self.image_seq_len), + self.mask_token_id, + dtype=torch.long, + device=self.device, + ) + + input_ids, attention_mask = self.uni_prompting( + ([instruction.strip()], input_tokens, output_placeholder), + "i2i_gen", + ) + input_ids = input_ids.to(self.device) + attention_mask = attention_mask.to(self.device) + + with torch.no_grad(): + gen_tokens = self.model.i2i_generate( + input_ids=input_ids, + attention_mask=attention_mask, + temperature=float(temperature), + timesteps=int(timesteps), + guidance_scale=float(guidance_scale), + noise_schedule=self.image_noise_schedule, + seq_len=self.image_seq_len, + mask_token_id=self.mask_token_id, + codebook_size=self.codebook_size, + uni_prompting=self.uni_prompting, + config=self.train_cfg, + ) + + if gen_tokens is None: + return None, "Image editing failed." + + gen_tokens = torch.clamp(gen_tokens, min=0, max=self.codebook_size - 1) + image = self._decode_image_tokens(gen_tokens[0]) + return image, "Edited image generated." + + # ------------------------------------------------------------------ + # Video-to-Speech + # ------------------------------------------------------------------ + def run_v2s( + self, + video_path: Any, + message: Optional[str], + max_new_tokens: int, + steps: int, + block_length: int, + temperature: float, + cfg_scale: float, + ) -> Tuple[Optional[Tuple[int, np.ndarray]], str]: + resolved_path, converted = self._prepare_video_path(video_path) + if not resolved_path: + return None, "Please upload or record a video first." + + try: + video_tokens = self._extract_video_tokens(resolved_path) + except Exception as exc: + return None, f"Failed to process video: {exc}" + + prompt_body = message.strip() if message and message.strip() else random.choice(V2S_INSTRUCTION) + prompt_text = ( + "<|start_header_id|>user<|end_header_id|>\n" + f"{prompt_body}" + "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n" + ) + + gen_len = max(1, int(max_new_tokens)) + gen_len = min(gen_len, self.max_audio_len_short) + gen_len, steps, block_length = self._prepare_block_schedule( + gen_len, + steps, + block_length, + ) + audio_placeholder = torch.full( + (1, gen_len), + self.mask_token_id, + dtype=torch.long, + device=self.device, + ) + + try: + seq_ids, attn_mask = self.uni_prompting( + (video_tokens, [prompt_text], [audio_placeholder]), + "v2s_gen", + ) + except Exception as exc: + return None, f"Failed to build V2S prompt: {exc}" + + input_ids = seq_ids.to(self.device) + attn_mask = attn_mask.to(self.device) + + with torch.no_grad(): + outputs = self.model.t2s_generate_mmu_like( + input_ids=input_ids, + max_new_tokens=gen_len, + steps=int(steps), + block_length=int(block_length), + temperature=float(temperature), + cfg_scale=float(cfg_scale), + mask_token_id=self.mask_token_id, + attention_mask=attn_mask, + uni_prompting=self.uni_prompting, + codebook_size=self.codebook_size, + audio_codebook_size=self.audio_codebook_size, + ) + + if not outputs: + return None, "Audio generation produced no tokens." + + generated = outputs[0] + if isinstance(generated, torch.Tensor): + generated = generated.detach().cpu() + + eoa_token_id = int(self.uni_prompting.sptids_dict['<|eoa|>'][0].item()) + mask_token_id = int(self.mask_token_id) + token_list = [] + for tok in generated.tolist(): + tok = int(tok) + if tok < 0: + continue + if tok == eoa_token_id: + break + if tok == mask_token_id: + continue + if tok >= self.audio_codebook_size: + continue + token_list.append(tok) + if not token_list: + return None, "Generated sequence was empty after decoding." + + speech_units = "".join(f"<|speech_{tok}|>" for tok in token_list) + fd, temp_path = tempfile.mkstemp(prefix="omada_v2s_", suffix=".wav") + os.close(fd) + condition = self._resolve_condition(self.audio_condition_default) + try: + wav = self.vq_audio.decode( + speech_units, + condition=condition, + output_wav_file=temp_path, + ) + except Exception as exc: + return None, f"Failed to decode speech: {exc}" + finally: + try: + os.remove(temp_path) + except OSError: + pass + + status = "Speech generated from video." + if converted: + status += " (Webcam recording converted to MP4.)" + status += f" (voice: {condition})" + return (self.sample_rate, wav.astype(np.float32)), status + + # ------------------------------------------------------------------ + # Image-to-Speech (This is a subset of s2s) + # ------------------------------------------------------------------ + def run_i2s( + self, + image: Optional[Image.Image], + message: Optional[str], + max_new_tokens: int, + steps: int, + block_length: int, + temperature: float, + cfg_scale: float, + ) -> Tuple[Optional[Tuple[int, np.ndarray]], str]: + if image is None: + return None, "Please upload an image first." + + question = (message or "").strip() or "Please describe the image in spoken form." + caption, status = self._mmu_answer([image], question) + if not caption: + return None, status + + speech_len, steps, block_length = self._prepare_block_schedule( + max_new_tokens, + steps, + block_length, + ) + + audio_result, speech_status = self.run_t2s( + caption, + speech_len, + steps, + block_length, + temperature, + cfg_scale, + 'random', + 'random', + 'random', + 'random', + ) + if audio_result is None: + return None, speech_status + combined_status = f"{status} {speech_status}".strip() + return audio_result, combined_status or "Spoken description generated." + + # ------------------------------------------------------------------ + # Chat (Text Generation) + # ------------------------------------------------------------------ + def run_chat( + self, + message: str, + max_new_tokens: int, + steps: int, + block_length: int, + temperature: float, + ) -> Tuple[str, str]: + content = (message or "").strip() + if not content: + return "", "Type a message to start chatting." + + prompt = ( + "<|start_header_id|>user<|end_header_id|>\n" + f"{content}" + "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n" + ) + tokenizer = self.uni_prompting.text_tokenizer + tokenizer.padding_side = "left" + if tokenizer.pad_token_id is None: + tokenizer.pad_token_id = tokenizer.eos_token_id + + tokens = tokenizer( + prompt, + return_tensors="pt", + truncation=True, + padding=True, + ) + input_ids = tokens["input_ids"].to(self.device) + attn_mask = tokens.get("attention_mask") + if attn_mask is not None: + attn_mask = attn_mask.to(self.device) + + with torch.no_grad(): + output_ids = self._generate_text_tokens( + input_ids, + max_new_tokens=int(max_new_tokens), + steps=int(steps), + block_length=int(block_length), + temperature=float(temperature), + cfg_scale=0.0, + attention_mask=attn_mask, + ) + + decoded = tokenizer.batch_decode( + output_ids[:, input_ids.shape[1]:], + skip_special_tokens=True, + )[0] + return decoded.strip(), "Assistant reply generated." + + # ------------------------------------------------------------------ + # Multi-image MMU (2 Images → Text) + # ------------------------------------------------------------------ + def run_mmu_dual( + self, + image_a: Optional[Image.Image], + image_b: Optional[Image.Image], + message: str, + max_new_tokens: int, + steps: int, + block_length: int, + temperature: float, + ) -> Tuple[str, str]: + images: List[Image.Image] = [] + if image_a is not None: + images.append(image_a) + if image_b is not None: + images.append(image_b) + if len(images) < 2: + return "", "Please provide two images for MMU reasoning." + + reply, status = self._mmu_answer( + images, + message, + max_new_tokens=max_new_tokens, + steps=steps, + block_length=block_length, + temperature=temperature, + ) + return reply, status + + # ------------------------------------------------------------------ + # Helpers + # ------------------------------------------------------------------ + def _resolve_choice(self, choice: Optional[str], options): + if choice is None or choice == 'random': + return random.choice(options) + return choice + + def _sample_condition(self) -> str: + return ( + f"gender-{random.choice(self.genders)}" + f"_emotion-{random.choice(self.emotions)}" + f"_speed-{random.choice(self.speeds)}" + f"_pitch-{random.choice(self.pitches)}" + ) + + def _resolve_condition(self, preferred: Optional[str] = None) -> str: + if preferred and preferred != "random": + if not self._valid_conditions or preferred in self._valid_conditions: + return preferred + if self._valid_conditions: + for _ in range(8): + candidate = self._sample_condition() + if candidate in self._valid_conditions: + return candidate + if self.audio_condition_default in self._valid_conditions: + return self.audio_condition_default + return next(iter(self._valid_conditions)) + if preferred and preferred != "random": + return preferred + return self.audio_condition_default + + def _format_chat_prompt(self, content: str) -> str: + clean = (content or "").strip() + if not clean: + clean = "Please describe the visual content." + return ( + "<|start_header_id|>user<|end_header_id|>\n" + f"{clean}\n" + "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n" + ) + + def _prepare_block_schedule( + self, + total_tokens: int, + steps: int, + block_length: int, + ) -> Tuple[int, int, int]: + total = max(1, int(total_tokens)) + blk = max(1, min(int(block_length), total)) + if total % blk != 0: + blk = math.gcd(total, blk) + blk = blk if blk > 0 else total + if total % blk != 0: + blk = total + num_blocks = max(1, total // blk) + steps = max(num_blocks, int(steps)) + if steps % num_blocks != 0: + steps = num_blocks * math.ceil(steps / num_blocks) + return total, steps, blk + + def _prepare_video_path(self, video_input: Any) -> Tuple[Optional[str], bool]: + """Normalize Gradio video inputs (upload/webcam) to an MP4 filepath.""" + + candidate = None + if isinstance(video_input, str): + candidate = video_input + elif isinstance(video_input, dict): + candidate = ( + video_input.get("video") + or video_input.get("name") + or video_input.get("path") + ) + elif isinstance(video_input, (list, tuple)) and video_input: + candidate = str(video_input[0]) + + if not candidate: + return None, False + + candidate = str(candidate) + if not self._ensure_file_ready(candidate): + return None, False + + if candidate.lower().endswith(".mp4"): + return candidate, False + + converted = self._convert_to_mp4(candidate) + if converted: + return converted, True + suffix = Path(candidate).suffix or ".webm" + fd, tmp_path = tempfile.mkstemp(prefix="omada_v2t_raw_", suffix=suffix) + os.close(fd) + try: + shutil.copy2(candidate, tmp_path) + self._temp_video_files.append(tmp_path) + return tmp_path, False + except OSError: + try: + os.remove(tmp_path) + except OSError: + pass + return candidate, False + + if candidate.lower().endswith(".mp4"): + return candidate, False + + converted = self._convert_to_mp4(candidate) + if converted: + return converted, True + suffix = Path(candidate).suffix or ".webm" + fd, tmp_path = tempfile.mkstemp(prefix="omada_v2t_raw_", suffix=suffix) + os.close(fd) + try: + shutil.copy2(candidate, tmp_path) + self._temp_video_files.append(tmp_path) + return tmp_path, False + except OSError: + try: + os.remove(tmp_path) + except OSError: + pass + return candidate, False + + def _ensure_file_ready(self, path: str, retries: int = 8, delay: float = 0.2) -> bool: + """Ensure the uploaded/recorded file is fully written before processing.""" + + prev_size = -1 + for _ in range(retries): + try: + size = os.path.getsize(path) + except OSError: + size = -1 + if size <= 0: + time.sleep(delay) + continue + if size == prev_size: + return True + prev_size = size + time.sleep(delay) + return prev_size > 0 + + def _convert_to_mp4(self, src_path: str) -> Optional[str]: + """Convert arbitrary video file to MP4 using OpenCV (drops audio).""" + + cap = cv2.VideoCapture(src_path) + if not cap.isOpened(): + return None + + width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH) or 0) + height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT) or 0) + fps = cap.get(cv2.CAP_PROP_FPS) + + if width <= 0 or height <= 0: + cap.release() + return None + + if not fps or np.isnan(fps) or fps <= 0: + fps = 24.0 + + fd, tmp_path = tempfile.mkstemp(prefix="omada_v2t_", suffix=".mp4") + os.close(fd) + + writer = cv2.VideoWriter( + tmp_path, + cv2.VideoWriter_fourcc(*"mp4v"), + float(fps), + (width, height), + ) + if not writer.isOpened(): + cap.release() + try: + os.remove(tmp_path) + except OSError: + pass + return None + + frame_count = 0 + try: + while True: + ret, frame = cap.read() + if not ret: + break + writer.write(frame) + frame_count += 1 + finally: + cap.release() + writer.release() + + if frame_count == 0: + try: + os.remove(tmp_path) + except OSError: + pass + return None + + self._temp_video_files.append(tmp_path) + return tmp_path + + def _extract_video_tokens(self, video_path: str) -> torch.Tensor: + cap = cv2.VideoCapture(video_path) + total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + if total_frames <= 0: + cap.release() + raise RuntimeError(f"No readable frames in {video_path}") + + indices = np.linspace(0, total_frames - 1, 8, dtype=int) + frames = [] + for idx in range(total_frames): + ret, frame = cap.read() + if idx in indices and ret: + rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + pil = Image.fromarray(rgb) + frames.append(image_transform(pil, resolution=self.train_cfg.dataset.preprocessing.resolution)) + cap.release() + + if len(frames) == 0: + raise RuntimeError("Failed to sample frames for V2T inference.") + + video_tensor = torch.stack(frames).to(self.device) + video_tokens = self.vq_image.get_code(video_tensor) + self.text_vocab_size + return video_tokens.long().to(self.device).view(1, -1) + + def _prepare_image_tokens(self, image: Image.Image) -> torch.LongTensor: + if image is None: + raise ValueError("Image input is required.") + tensor = image_transform(image, resolution=self.image_resolution) + tensor = tensor.unsqueeze(0).to(self.device) + codes = self.vq_image.get_code(tensor) + self.text_vocab_size + return codes.long().to(self.device) + + def _decode_image_tokens(self, tokens: torch.Tensor) -> Image.Image: + codes = tokens.view(1, -1).clamp(min=0, max=self.codebook_size - 1).to(self.device) + with torch.no_grad(): + image_tensor = self.vq_image.decode_code(codes) + image_tensor = image_tensor.squeeze(0).cpu() + image_tensor = torch.clamp((image_tensor.float() + 1.0) / 2.0, min=0.0, max=1.0) + array = (image_tensor.permute(1, 2, 0).numpy() * 255.0).astype(np.uint8) + return Image.fromarray(array) + + def _mmu_answer( + self, + images: List[Image.Image], + question: str, + max_new_tokens: Optional[int] = None, + steps: Optional[int] = None, + block_length: Optional[int] = None, + temperature: Optional[float] = None, + ) -> Tuple[str, str]: + if not images: + return "", "Please provide at least one image." + + encoded_images: List[torch.Tensor] = [] + for image in images: + if image is None: + continue + try: + tokens = self._prepare_image_tokens(image).view(-1) + encoded_images.append(tokens) + except Exception: + continue + + if not encoded_images: + return "", "Failed to encode the provided image(s)." + + question = (question or "").strip() or "Describe the visual content." + prompt = self._format_chat_prompt(question) + try: + tokenized = self.uni_prompting.text_tokenizer( + [prompt], + add_special_tokens=False, + )["input_ids"][0] + except Exception as exc: + return "", f"Failed to tokenize question: {exc}" + + try: + mmu_input_ids, prompt_masks, _ = self.uni_prompting.mmu_mult_prompt( + batch_image_ids_list=[encoded_images], + batch_text_ids=[tokenized], + ) + except Exception as exc: + return "", f"Failed to construct MMU prompt: {exc}" + + mmu_input_ids = mmu_input_ids.to(self.device) + prompt_masks = prompt_masks.to(self.device) + + answer_tokens = int((prompt_masks == 0).sum(dim=1).max().item()) + default_budget = max(1, answer_tokens) if answer_tokens > 0 else min(self.max_text_len, 256) + gen_tokens = int(max_new_tokens or default_budget) + requested_steps = steps if steps is not None else gen_tokens + requested_block = block_length if block_length is not None else max(1, gen_tokens // 2) + gen_tokens, steps, block_length = self._prepare_block_schedule( + gen_tokens, + requested_steps, + requested_block, + ) + temperature = float(temperature if temperature is not None else 0.7) + + if gen_tokens > 0: + mask_block = torch.full( + (mmu_input_ids.size(0), gen_tokens), + self.mask_token_id, + dtype=torch.long, + device=self.device, + ) + mmu_input_ids = torch.cat([mmu_input_ids, mask_block], dim=1) + + with torch.no_grad(): + output_ids = self.model.mmu_generate( + mmu_input_ids, + max_new_tokens=int(gen_tokens), + steps=int(steps), + block_length=int(block_length), + temperature=temperature, + remasking="low_confidence", + mask_id=self.mask_token_id, + ) + + decoded = self.uni_prompting.text_tokenizer.batch_decode( + output_ids[:, mmu_input_ids.shape[1]:], + skip_special_tokens=True, + )[0].strip() + if not decoded: + return "", "MMU response was empty." + return decoded, "Image understanding succeeded." + + def _generate_text_tokens( + self, + prompt_ids: torch.Tensor, + max_new_tokens: int, + steps: int, + block_length: int, + temperature: float, + cfg_scale: float = 0.0, + attention_mask: Optional[torch.Tensor] = None, + remasking: str = "low_confidence", + ) -> torch.Tensor: + prompt_ids = prompt_ids.to(self.device) + batch_size, prompt_len = prompt_ids.shape + + gen_len, steps, block_length = self._prepare_block_schedule( + max_new_tokens, + steps, + block_length, + ) + + work = torch.full( + (batch_size, prompt_len + gen_len), + self.mask_token_id, + dtype=torch.long, + device=self.device, + ) + work[:, :prompt_len] = prompt_ids + + prompt_index = work != self.mask_token_id + + attention_bias = None + if attention_mask is not None and (attention_mask == 0).any(): + attention_bias = (attention_mask[:, :, None] & attention_mask[:, None, :]).bool().unsqueeze(1) + + num_blocks = max(1, gen_len // block_length) + inner_steps = max(1, steps // num_blocks) + + for block_idx in range(num_blocks): + block_slice = slice(prompt_len + block_idx * block_length, prompt_len + (block_idx + 1) * block_length) + block_mask_index = work[:, block_slice] == self.mask_token_id + num_transfer_tokens = get_num_transfer_tokens(block_mask_index, inner_steps) + + for inner_step in range(inner_steps): + mask_index = work == self.mask_token_id + + if cfg_scale > 0.0: + unconditional = work.clone() + unconditional[prompt_index] = self.mask_token_id + model_input = torch.cat([work, unconditional], dim=0) + logits = self.model(model_input).logits + cond_logits, uncond_logits = torch.chunk(logits, 2, dim=0) + logits = uncond_logits + (cfg_scale + 1.0) * (cond_logits - uncond_logits) + else: + logits = self.model(work, attention_bias=attention_bias).logits + + logits_with_noise = add_gumbel_noise(logits, temperature=temperature) + x0 = torch.argmax(logits_with_noise, dim=-1) + + if remasking == "low_confidence": + probs = F.softmax(logits.to(torch.float64), dim=-1) + x0_p = torch.gather(probs, -1, x0.unsqueeze(-1)).squeeze(-1) + elif remasking == "random": + x0_p = torch.rand_like(x0, dtype=torch.float64) + else: + raise NotImplementedError(remasking) + + x0_p[:, prompt_len + (block_idx + 1) * block_length :] = -np.inf + + x0 = torch.where(mask_index, x0, work) + confidence = torch.where(mask_index, x0_p, torch.full_like(x0_p, float('-inf'))) + + transfer_index = torch.zeros_like(work, dtype=torch.bool) + for b in range(batch_size): + k = int(num_transfer_tokens[b, inner_step].item()) + if k <= 0: + continue + values, select_idx = torch.topk(confidence[b], k=k) + transfer_index[b, select_idx] = values != float('-inf') + + work[transfer_index] = x0[transfer_index] + + return work + +def build_demo(app: OmadaDemo, share: bool, server_name: str, server_port: Optional[int]): + theme = gr.themes.Soft(primary_hue="blue", neutral_hue="gray") + with gr.Blocks(title="AIDAS Lab @ SNU", css=CUSTOM_CSS, theme=theme, js=FORCE_LIGHT_MODE_JS) as demo: + with gr.Column(elem_classes=["omada-header"]): + if LOGO_DATA_URI: + gr.HTML( + f"
\"AIDAS
" + ) + elif LOGO_PATH.exists(): + gr.Image( + value=str(LOGO_PATH), + show_label=False, + height=140, + interactive=False, + elem_classes=["omada-logo"], + ) + gr.Markdown( + "## Omni-modal Diffusion Foundation Model\n### Pretrained Demo", + elem_classes=["omada-title"], + ) + gr.HTML( + "

" + "Create speech, " + "transcribe audio, " + "describe video, " + "chat with text, and " + "generate or edit images — all from a single model. " + "Use the advanced sections when you want tighter control." + "

") + + group_to_modes = { + "Any → Speech": ["Text → Speech", "Speech → Speech", "Video → Speech", "Image → Speech"], + "Any → Text": ["Speech → Text", "Video → Text", "Chat", "MMU (2 Images → Text)"], + "Image Generation": ["Text → Image", "Image Editing"], + } + default_group = "Any → Speech" + default_mode = group_to_modes[default_group][0] + placeholder_map = { + "Text → Speech": "Type the speech you want to generate...", + "Speech → Speech": "Optionally add context for the reply...", + "Video → Speech": "Upload video on the right. Optionally provide guidance here.", + "Image → Speech": "Upload an image on the right and add guidance if needed.", + "Speech → Text": "Upload audio on the right, then leave notes here if needed.", + "Video → Text": "Upload video on the right, then leave notes here if needed.", + "Chat": "Ask anything and the assistant will reply with text.", + "MMU (2 Images → Text)": "Ask a question about the two uploaded images.", + "Text → Image": "Describe the image you want to generate...", + "Image Editing": "Describe how you want to edit the uploaded image...", + } + with gr.Row(elem_classes=["omada-layout"], equal_height=False): + with gr.Column(scale=3, min_width=480, elem_classes=["omada-chat-column"]): + chatbox = gr.Chatbot(label="Session", height=760, sanitize_html=False) + chat_input = gr.Textbox( + label="Message", + placeholder=placeholder_map[default_mode], + lines=3, + ) + with gr.Row(): + send_button = gr.Button("Send", variant="primary") + clear_button = gr.Button("Clear", variant="secondary") + with gr.Column(scale=2, min_width=360, elem_classes=["omada-controls"]): + task_group_selector = gr.Radio( + list(group_to_modes.keys()), + value=default_group, + label="Task Group", + elem_classes=["omada-task-selector", "omada-task-group-buttons"], + ) + submode_selector = gr.Radio( + group_to_modes[default_group], + value=default_mode, + label="Task Mode", + elem_classes=["omada-task-selector", "omada-task-mode-buttons"], + ) + with gr.Column(visible=True, elem_classes=["omada-mode-panel"]) as t2s_panel: + with gr.Group(elem_classes=["omada-card"]): + gr.Markdown("### Text-to-Speech Controls") + with gr.Group(elem_classes=["omada-advanced"]): + gr.Markdown("**Generation**") + with gr.Row(): + t2s_max_tokens = gr.Slider(2, 512, value=384, label="Speech token length", step=2) + t2s_steps = gr.Slider(2, 512, value=128, label="Total refinement steps", step=2) + with gr.Row(): + t2s_block = gr.Slider(2, 512, value=128, label="Block length", step=2) + t2s_cfg = gr.Slider(0.0, 6.0, value=3.5, label="CFG scale", step=0.1) + t2s_temperature = gr.Slider(0.0, 2.0, value=1.0, label="Sampling temperature", step=0.05) + with gr.Group(elem_classes=["omada-advanced"]): + gr.Markdown("**Voice styling**") + with gr.Row(): + t2s_gender = gr.Dropdown(['random'] + app.genders, value='random', label="Voice gender") + t2s_emotion = gr.Dropdown(['random'] + app.emotions, value='random', label="Emotion") + with gr.Row(): + t2s_speed = gr.Dropdown(['random'] + app.speeds, value='random', label="Speaking speed") + t2s_pitch = gr.Dropdown(['random'] + app.pitches, value='random', label="Pitch") + if T2S_EXAMPLES: + with gr.Group(elem_classes=["omada-card", "omada-examples-card"]): + gr.Markdown("**Sample prompts**") + with gr.Column(elem_classes=["omada-examples"]): + gr.Examples( + examples=T2S_EXAMPLES, + inputs=[chat_input], + examples_per_page=4, + ) + with gr.Column(visible=False, elem_classes=["omada-mode-panel"]) as s2s_panel: + with gr.Group(elem_classes=["omada-card"]): + gr.Markdown("### Speech-to-Speech Controls") + s2s_audio = gr.Audio(type="filepath", label="Source speech", sources=["microphone", "upload"]) + with gr.Accordion("Generation settings", open=True, elem_classes=["omada-advanced"]): + s2s_max_tokens = gr.Slider(2, 512, value=256, label="Reply token length", step=2) + with gr.Row(): + s2s_steps = gr.Slider(2, 512, value=128, label="Refinement steps", step=2) + s2s_block = gr.Slider(2, 512, value=128, label="Block length", step=2) + with gr.Row(): + s2s_temperature = gr.Slider(0.0, 2.0, value=0.0, label="Sampling temperature", step=0.05) + s2s_cfg = gr.Slider(0.0, 6.0, value=4.0, label="CFG scale", step=0.1) + if S2S_EXAMPLES: + with gr.Group(elem_classes=["omada-card", "omada-examples-card"]): + gr.Markdown("**Sample S2S clips**") + with gr.Column(elem_classes=["omada-examples"]): + gr.Examples( + examples=S2S_EXAMPLES, + inputs=[s2s_audio], + examples_per_page=4, + ) + with gr.Column(visible=False, elem_classes=["omada-mode-panel"]) as s2t_panel: + with gr.Group(elem_classes=["omada-card"]): + gr.Markdown("### Speech-to-Text Controls") + s2t_audio = gr.Audio(type="filepath", label="Speech input", sources=["microphone", "upload"]) + with gr.Accordion("Generation settings", open=True, elem_classes=["omada-advanced"]): + with gr.Row(): + s2t_steps = gr.Slider(2, 512, value=128, label="Denoising steps", step=2) + s2t_block = gr.Slider(2, 512, value=128, label="Block length", step=2) + s2t_max_tokens = gr.Slider(2, 512, value=128, label="Max tokens", step=2) + s2t_remasking = gr.Dropdown( + choices=["low_confidence", "random"], + value="low_confidence", + label="Remasking strategy", + ) + if S2T_EXAMPLES: + with gr.Group(elem_classes=["omada-card", "omada-examples-card"]): + gr.Markdown("**Sample clips**") + with gr.Column(elem_classes=["omada-examples"]): + gr.Examples( + examples=S2T_EXAMPLES, + inputs=[s2t_audio], + examples_per_page=4, + ) + with gr.Column(visible=False, elem_classes=["omada-mode-panel"]) as v2t_panel: + with gr.Group(elem_classes=["omada-card"]): + gr.Markdown("### Video-to-Text Controls") + v2t_video = gr.Video( + label="Upload or record video", + format=None, + height=256, + sources=["upload", "webcam"], + ) + with gr.Accordion("Generation settings", open=True, elem_classes=["omada-advanced"]): + with gr.Row(): + v2t_steps = gr.Slider(2, 512, value=64, label="Denoising steps", step=2) + v2t_block = gr.Slider(2, 512, value=64, label="Block length", step=2) + v2t_max_tokens = gr.Slider(2, 512, value=64, label="Max tokens", step=2) + if V2T_EXAMPLES: + with gr.Group(elem_classes=["omada-card", "omada-examples-card"]): + gr.Markdown("**Sample videos**") + with gr.Column(elem_classes=["omada-examples"]): + gr.Examples( + examples=V2T_EXAMPLES, + inputs=[v2t_video], + examples_per_page=4, + ) + with gr.Column(visible=False, elem_classes=["omada-mode-panel"]) as v2s_panel: + with gr.Group(elem_classes=["omada-card"]): + gr.Markdown("### Video-to-Speech Controls") + v2s_video = gr.Video( + label="Upload or record video", + format=None, + height=256, + sources=["upload", "webcam"], + ) + with gr.Accordion("Generation settings", open=True, elem_classes=["omada-advanced"]): + v2s_max_tokens = gr.Slider(2, 512, value=256, label="Reply token length", step=2) + with gr.Row(): + v2s_steps = gr.Slider(2, 512, value=128, label="Refinement steps", step=2) + v2s_block = gr.Slider(2, 512, value=128, label="Block length", step=2) + with gr.Row(): + v2s_temperature = gr.Slider(0.0, 2.0, value=1.0, label="Sampling temperature", step=0.05) + v2s_cfg = gr.Slider(0.0, 6.0, value=3.0, label="CFG scale", step=0.1) + if V2S_EXAMPLES: + with gr.Group(elem_classes=["omada-card", "omada-examples-card"]): + gr.Markdown("**Sample videos**") + with gr.Column(elem_classes=["omada-examples"]): + gr.Examples( + examples=V2S_EXAMPLES, + inputs=[v2s_video], + examples_per_page=4, + ) + with gr.Column(visible=False, elem_classes=["omada-mode-panel"]) as i2s_panel: + with gr.Group(elem_classes=["omada-card"]): + gr.Markdown("### Image-to-Speech Controls") + i2s_image = gr.Image(type="pil", label="Upload image", sources=["upload"]) + with gr.Accordion("Generation settings", open=True, elem_classes=["omada-advanced"]): + i2s_max_tokens = gr.Slider(2, 512, value=256, label="Reply token length", step=2) + with gr.Row(): + i2s_steps = gr.Slider(2, 512, value=256, label="Refinement steps", step=2) + i2s_block = gr.Slider(2, 512, value=256, label="Block length", step=2) + with gr.Row(): + i2s_temperature = gr.Slider(0.0, 2.0, value=1.0, label="Sampling temperature", step=0.05) + i2s_cfg = gr.Slider(0.0, 6.0, value=3.0, label="CFG scale", step=0.1) + if I2S_EXAMPLES: + with gr.Group(elem_classes=["omada-card", "omada-examples-card"]): + gr.Markdown("**Sample images**") + with gr.Column(elem_classes=["omada-examples"]): + gr.Examples( + examples=I2S_EXAMPLES, + inputs=[i2s_image], + examples_per_page=4, + ) + with gr.Column(visible=False, elem_classes=["omada-mode-panel"]) as image_panel: + with gr.Group(elem_classes=["omada-card"]): + gr.Markdown("### Image Tasks") + image_mode_selector = gr.Radio( + ["Generation", "Editing"], + value="Generation", + label="Sub-mode", + ) + with gr.Accordion("Generation settings", open=True, elem_classes=["omada-advanced"], visible=True) as t2i_settings: + t2i_timesteps = gr.Slider(4, 128, value=32, label="Timesteps", step=2) + t2i_temperature = gr.Slider(0.0, 2.0, value=1.0, label="Sampling temperature", step=0.05) + t2i_guidance = gr.Slider(0.0, 8.0, value=3.5, label="CFG scale", step=0.1) + with gr.Accordion("Editing settings", open=True, elem_classes=["omada-advanced"], visible=False) as i2i_settings: + i2i_image = gr.Image(type="pil", label="Reference image", sources=["upload"]) + i2i_timesteps = gr.Slider(4, 128, value=18, label="Timesteps", step=2) + i2i_temperature = gr.Slider(0.0, 2.0, value=1.0, label="Sampling temperature", step=0.05) + i2i_guidance = gr.Slider(0.0, 8.0, value=3.5, label="CFG scale", step=0.1) + if T2I_EXAMPLES: + with gr.Group(elem_classes=["omada-card", "omada-examples-card"]): + gr.Markdown("**Sample prompts**") + with gr.Column(elem_classes=["omada-examples"]): + gr.Examples( + examples=T2I_EXAMPLES, + inputs=[chat_input], + examples_per_page=4, + ) + with gr.Column(visible=False, elem_classes=["omada-mode-panel"]) as chat_panel: + with gr.Group(elem_classes=["omada-card"]): + gr.Markdown("### Chat Controls") + with gr.Accordion("Generation settings", open=True, elem_classes=["omada-advanced"]): + chat_max_tokens = gr.Slider(2, 512, value=64, label="Reply max tokens", step=2) + with gr.Row(): + chat_steps = gr.Slider(2, 512, value=64, label="Refinement steps", step=2) + chat_block = gr.Slider(2, 512, value=64, label="Block length", step=2) + chat_temperature = gr.Slider(0.0, 2.0, value=0.8, label="Sampling temperature", step=0.05) + if CHAT_EXAMPLES: + with gr.Group(elem_classes=["omada-card", "omada-examples-card"]): + gr.Markdown("**Sample prompts**") + with gr.Column(elem_classes=["omada-examples"]): + gr.Examples( + examples=CHAT_EXAMPLES, + inputs=[chat_input], + examples_per_page=4, + ) + with gr.Column(visible=False, elem_classes=["omada-mode-panel"]) as mmu_panel: + with gr.Group(elem_classes=["omada-card"]): + gr.Markdown("### Multi-image Reasoning") + mmu_image_a = gr.Image(type="pil", label="Image A", sources=["upload"]) + mmu_image_b = gr.Image(type="pil", label="Image B", sources=["upload"]) + with gr.Accordion("Generation settings", open=True, elem_classes=["omada-advanced"]): + mmu_max_tokens = gr.Slider(2, 512, value=256, label="Answer max tokens", step=2) + with gr.Row(): + mmu_steps = gr.Slider(2, 512, value=256, label="Refinement steps", step=2) + mmu_block = gr.Slider(2, 512, value=128, label="Block length", step=2) + mmu_temperature = gr.Slider(0.0, 2.0, value=0.7, label="Sampling temperature", step=0.05) + if MMU_EXAMPLES: + with gr.Group(elem_classes=["omada-card", "omada-examples-card"]): + gr.Markdown("**Sample prompts**") + with gr.Column(elem_classes=["omada-examples"]): + gr.Examples( + examples=MMU_EXAMPLES, + inputs=[mmu_image_a, mmu_image_b, chat_input], + examples_per_page=1, + ) + + def _panel_updates(group: str, mode: str): + show_t2s = group == "Any → Speech" and mode == "Text → Speech" + show_s2s = group == "Any → Speech" and mode == "Speech → Speech" + show_v2s = group == "Any → Speech" and mode == "Video → Speech" + show_i2s = group == "Any → Speech" and mode == "Image → Speech" + show_s2t = group == "Any → Text" and mode == "Speech → Text" + show_v2t = group == "Any → Text" and mode == "Video → Text" + show_chat = group == "Any → Text" and mode == "Chat" + show_mmu = group == "Any → Text" and mode == "MMU (2 Images → Text)" + show_image = group == "Image Generation" + placeholder = placeholder_map.get(mode, chat_input.placeholder) + image_mode_value = "Generation" if mode == "Text → Image" else "Editing" + t2i_visible = show_image and mode == "Text → Image" + i2i_visible = show_image and mode == "Image Editing" + image_mode_update = gr.update(value=image_mode_value) if show_image else gr.update() + return ( + gr.update(placeholder=placeholder), + gr.update(visible=show_t2s), + gr.update(visible=show_s2s), + gr.update(visible=show_v2s), + gr.update(visible=show_i2s), + gr.update(visible=show_s2t), + gr.update(visible=show_v2t), + gr.update(visible=show_chat), + gr.update(visible=show_mmu), + gr.update(visible=show_image), + image_mode_update, + gr.update(visible=t2i_visible), + gr.update(visible=i2i_visible), + ) + + def _on_group_change(group: str): + default_mode_local = group_to_modes[group][0] + submode_update = gr.update(choices=group_to_modes[group], value=default_mode_local) + panel_updates = _panel_updates(group, default_mode_local) + return (submode_update, *panel_updates) + + def _on_submode_change(mode: str, group: str): + return _panel_updates(group, mode) + + task_group_selector.change( + _on_group_change, + inputs=[task_group_selector], + outputs=[ + submode_selector, + chat_input, + t2s_panel, + s2s_panel, + v2s_panel, + i2s_panel, + s2t_panel, + v2t_panel, + chat_panel, + mmu_panel, + image_panel, + image_mode_selector, + t2i_settings, + i2i_settings, + ], + ) + + submode_selector.change( + _on_submode_change, + inputs=[submode_selector, task_group_selector], + outputs=[ + chat_input, + t2s_panel, + s2s_panel, + v2s_panel, + i2s_panel, + s2t_panel, + v2t_panel, + chat_panel, + mmu_panel, + image_panel, + image_mode_selector, + t2i_settings, + i2i_settings, + ], + ) + + def _toggle_image_task(task: str): + return ( + gr.update(visible=task == "Generation"), + gr.update(visible=task == "Editing"), + ) + + image_mode_selector.change( + _toggle_image_task, + inputs=[image_mode_selector], + outputs=[t2i_settings, i2i_settings], + ) + + def _chat_handler( + history, + message, + group, + mode, + s2t_audio_path, + v2t_video_path, + t2s_max_tokens, + t2s_steps, + t2s_block, + t2s_temperature, + t2s_cfg, + t2s_gender, + t2s_emotion, + t2s_speed, + t2s_pitch, + s2t_steps, + s2t_block, + s2t_max_tokens, + s2t_remasking, + v2t_steps, + v2t_block, + v2t_max_tokens, + s2s_audio_path, + s2s_max_tokens, + s2s_steps, + s2s_block, + s2s_temperature, + s2s_cfg, + i2s_image, + i2s_max_tokens, + i2s_steps, + i2s_block, + i2s_temperature, + i2s_cfg, + image_mode, + t2i_timesteps, + t2i_temperature, + t2i_guidance, + i2i_image, + i2i_timesteps, + i2i_temperature, + i2i_guidance, + v2s_video_path, + v2s_max_tokens, + v2s_steps, + v2s_block, + v2s_temperature, + v2s_cfg, + chat_max_tokens, + chat_steps, + chat_block, + chat_temperature, + mmu_image_a, + mmu_image_b, + mmu_max_tokens, + mmu_steps, + mmu_block, + mmu_temperature, + ): + history = history or [] + message = (message or "").strip() + response = "" + + if group == "Any → Speech": + if mode == "Text → Speech": + if not message: + status = "Please type some text for speech synthesis." + response = _render_text_message(status, "") + else: + audio_result, status = app.run_t2s( + message, + t2s_max_tokens, + t2s_steps, + t2s_block, + t2s_temperature, + t2s_cfg, + t2s_gender, + t2s_emotion, + t2s_speed, + t2s_pitch, + ) + response = _render_audio_message(status, audio_result) + display_user_raw = message or "[Speech request]" + elif mode == "Speech → Speech": + audio_result, status = app.run_s2s( + s2s_audio_path, + s2s_max_tokens, + s2s_steps, + s2s_block, + s2s_temperature, + s2s_cfg, + ) + response = _render_audio_message(status, audio_result) + display_user_raw = message or "[Speech-to-speech request]" + elif mode == "Video → Speech": + audio_result, status = app.run_v2s( + v2s_video_path, + message, + v2s_max_tokens, + v2s_steps, + v2s_block, + v2s_temperature, + v2s_cfg, + ) + response = _render_audio_message(status, audio_result) + display_user_raw = message or "[Video-to-speech request]" + else: # Image → Speech + audio_result, status = app.run_i2s( + i2s_image, + message, + i2s_max_tokens, + i2s_steps, + i2s_block, + i2s_temperature, + i2s_cfg, + ) + response = _render_audio_message(status, audio_result) + display_user_raw = message or "[Image-to-speech request]" + elif group == "Any → Text": + if mode == "Speech → Text": + if not s2t_audio_path: + status = "Please upload or record an audio clip first." + response = _render_text_message(status, "") + else: + transcript, status = app.run_s2t( + s2t_audio_path, + s2t_steps, + s2t_block, + s2t_max_tokens, + s2t_remasking, + ) + response = _render_text_message(status, transcript) + display_user_raw = message or "[Audio transcription request]" + elif mode == "Video → Text": + if not v2t_video_path: + status = "Please upload or record a video first." + response = _render_text_message(status, "") + else: + caption, status = app.run_v2t( + v2t_video_path, + v2t_steps, + v2t_block, + v2t_max_tokens, + ) + response = _render_text_message(status, caption) + display_user_raw = message or "[Video caption request]" + elif mode == "Chat": + reply, status = app.run_chat( + message, + chat_max_tokens, + chat_steps, + chat_block, + chat_temperature, + ) + response = _render_text_message(status, reply) + display_user_raw = message or "[Chat request]" + else: # MMU (2 Images → Text) + reply, status = app.run_mmu_dual( + mmu_image_a, + mmu_image_b, + message, + mmu_max_tokens, + mmu_steps, + mmu_block, + mmu_temperature, + ) + response = _render_text_message(status, reply) + display_user_raw = message or "[Multi-image question]" + else: # Image Generation + if mode == "Text → Image": + if not message: + status = "Please provide a prompt for image generation." + response = _render_text_message(status, "") + else: + image_result, status = app.run_t2i( + message, + t2i_timesteps, + t2i_temperature, + t2i_guidance, + ) + response = _render_image_message(status, image_result) + display_user_raw = message or "[Image generation request]" + else: # Image Editing + image_result, status = app.run_i2i( + message, + i2i_image, + i2i_timesteps, + i2i_temperature, + i2i_guidance, + ) + response = _render_image_message(status, image_result) + display_user_raw = message or "[Image editing request]" + + if not response: + status = f"Mode '{mode}' is not supported." + response = _render_text_message(status, "") + display_user_raw = message or "[Unsupported mode]" + + display_user = _format_user_message(display_user_raw) + history = history + [(display_user, response)] + return history, "" + + submit_inputs = [ + chatbox, + chat_input, + task_group_selector, + submode_selector, + s2t_audio, + v2t_video, + t2s_max_tokens, + t2s_steps, + t2s_block, + t2s_temperature, + t2s_cfg, + t2s_gender, + t2s_emotion, + t2s_speed, + t2s_pitch, + s2t_steps, + s2t_block, + s2t_max_tokens, + s2t_remasking, + v2t_steps, + v2t_block, + v2t_max_tokens, + s2s_audio, + s2s_max_tokens, + s2s_steps, + s2s_block, + s2s_temperature, + s2s_cfg, + i2s_image, + i2s_max_tokens, + i2s_steps, + i2s_block, + i2s_temperature, + i2s_cfg, + image_mode_selector, + t2i_timesteps, + t2i_temperature, + t2i_guidance, + i2i_image, + i2i_timesteps, + i2i_temperature, + i2i_guidance, + v2s_video, + v2s_max_tokens, + v2s_steps, + v2s_block, + v2s_temperature, + v2s_cfg, + chat_max_tokens, + chat_steps, + chat_block, + chat_temperature, + mmu_image_a, + mmu_image_b, + mmu_max_tokens, + mmu_steps, + mmu_block, + mmu_temperature, + ] + submit_outputs = [chatbox, chat_input] + + chat_input.submit(_chat_handler, inputs=submit_inputs, outputs=submit_outputs) + send_button.click(_chat_handler, inputs=submit_inputs, outputs=submit_outputs) + + def _clear_session(): + return ( + [], + "", + gr.update(value=None), + gr.update(value=None), + gr.update(value=None), + gr.update(value=None), + gr.update(value=None), + gr.update(value=None), + gr.update(value=None), + gr.update(value=None), + ) + + clear_button.click( + _clear_session, + inputs=None, + outputs=[ + chatbox, + chat_input, + s2t_audio, + v2t_video, + s2s_audio, + i2s_image, + v2s_video, + i2i_image, + mmu_image_a, + mmu_image_b, + ], + ) + + demo.launch( + share=share, + server_name=server_name, + server_port=server_port, + ) + + +def parse_args(): + parser = argparse.ArgumentParser(description="OMada Gradio demo for audio/video tasks") + parser.add_argument("--train-config", required=True, help="Path to the training YAML used to build tokenizer + VQ modules") + parser.add_argument("--checkpoint", required=True, help="Path to an `unwrapped_model` directory") + parser.add_argument("--device", default=None, help="Override device (e.g. cuda:0). Defaults to CUDA if available") + parser.add_argument("--share", action="store_true", help="Enable public Gradio share link") + parser.add_argument("--server-name", default="0.0.0.0", help="Host address for Blocks.launch") + parser.add_argument("--server-port", type=int, default=None, help="Port for Blocks.launch") + return parser.parse_args() + + +def main(): + args = parse_args() + app = OmadaDemo(args.train_config, args.checkpoint, args.device) + build_demo(app, args.share, args.server_name, args.server_port) + + +if __name__ == "__main__": + main() diff --git a/MMaDA/inference/run_s2t.sh b/MMaDA/inference/run_s2t.sh new file mode 100755 index 0000000000000000000000000000000000000000..d43b0738ffd56eac5bd379b809fc4e8a070bf2f7 --- /dev/null +++ b/MMaDA/inference/run_s2t.sh @@ -0,0 +1,67 @@ +#!/usr/bin/env bash +set -euo pipefail + +# Editable defaults (override via KEY=VALUE or extra CLI flags) +TRAIN_CONFIG=${TRAIN_CONFIG:-MMaDA/configs/omada_pretraining_stage1-3.yaml} +CKPT_ROOT=${CKPT_ROOT:-ckpts/omada/omada-training-stage1_3rd} +INFER_CONFIG=${INFER_CONFIG:-} +CHECKPOINTS=${CHECKPOINTS:-} + +# Generation params +STEPS=${STEPS:-} +BLOCK_LENGTH=${BLOCK_LENGTH:-} +MAX_NEW_TOKENS=${MAX_NEW_TOKENS:-} +REMASKING=${REMASKING:-} +BATCH_SIZE=${BATCH_SIZE:-} + +# Dataset params +SUBSET=${SUBSET:-} +SPLIT=${SPLIT:-} +ROOT_PATH=${ROOT_PATH:-} +LIMIT=${LIMIT:-} +TEXT_NORM=${TEXT_NORM:-} + +# Allow KEY=VALUE pairs override; keep other CLI flags in REST_ARGS +REST_ARGS=() +for arg in "$@"; do + case "$arg" in + *=*) eval "$arg" ;; + *) REST_ARGS+=("$arg") ;; + esac +done + +ARGS=( + --train_config "$TRAIN_CONFIG" + --ckpt_root "$CKPT_ROOT" +) + +if [[ -n "$INFER_CONFIG" ]]; then + ARGS+=(--infer_config "$INFER_CONFIG") +fi + +[[ -n "$STEPS" ]] && ARGS+=(--steps "$STEPS") +[[ -n "$BLOCK_LENGTH" ]] && ARGS+=(--block_length "$BLOCK_LENGTH") +[[ -n "$MAX_NEW_TOKENS" ]] && ARGS+=(--max_new_tokens "$MAX_NEW_TOKENS") +[[ -n "$REMASKING" ]] && ARGS+=(--remasking "$REMASKING") +[[ -n "$BATCH_SIZE" ]] && ARGS+=(--batch_size "$BATCH_SIZE") + +[[ -n "$SUBSET" ]] && ARGS+=(--subset "$SUBSET") +[[ -n "$SPLIT" ]] && ARGS+=(--split "$SPLIT") +[[ -n "$ROOT_PATH" ]] && ARGS+=(--root_path "$ROOT_PATH") +[[ -n "$LIMIT" ]] && ARGS+=(--limit "$LIMIT") +[[ -n "$TEXT_NORM" ]] && ARGS+=(--text_norm "$TEXT_NORM") + +if [[ -n "$CHECKPOINTS" ]]; then + IFS=',' read -r -a CK_ARR <<< "$CHECKPOINTS" + for c in "${CK_ARR[@]}"; do + ARGS+=(--checkpoint "$c") + done +fi + +python -u MMaDA/inference/s2t_infer.py "${ARGS[@]}" "${REST_ARGS[@]}" + +# Example: +# bash MMaDA/inference/run_s2t.sh STEPS=256 BLOCK_LENGTH=256 MAX_NEW_TOKENS=256 BATCH_SIZE=4 ROOT_PATH=/home/work/AIDAS/data/audio/LibriSpeech/test-clean TRAIN_CONFIG=MMaDA/configs/omada_pretraining_stage1-3.yaml CKPT_ROOT=ckpts/omada/omada-training-stage1_7th/checkpoint-315000/unwrapped_model + +# +# for step in {100000..310000..10000}; do ckpt="ckpts/omada/omada-training-stage1_3rd/checkpoint-$step/unwrapped_model"; if [[ -d "$ckpt" ]]; then echo "Running S2T checkpoint-$step"; bash MMaDA/inference/run_s2t.sh TRAIN_CONFIG=MMaDA/configs/omada_pretraining_stage1-3.yaml CHECKPOINTS="$ckpt" STEPS=256 BLOCK_LENGTH=256 MAX_NEW_TOKENS=512 BATCH_SIZE=4; else echo "Skip missing $ckpt"; fi; done diff --git a/MMaDA/inference/run_t2s.sh b/MMaDA/inference/run_t2s.sh new file mode 100755 index 0000000000000000000000000000000000000000..d73bb07e9a90c561dab0cc84144bec64e24fd8bd --- /dev/null +++ b/MMaDA/inference/run_t2s.sh @@ -0,0 +1,107 @@ +#!/usr/bin/env bash +set -euo pipefail + +# Editable defaults (can be overridden by KEY=VALUE pairs or by passing CLI flags) +TRAIN_CONFIG=${TRAIN_CONFIG:-MMaDA/configs/omada_pretraining_stage1-3.yaml} +CKPT_ROOT=${CKPT_ROOT:-ckpts/omada/omada-training-stage1_3rd} +INFER_CONFIG=${INFER_CONFIG:-} +CHECKPOINTS=${CHECKPOINTS:-} + +# Generation params +MODE=${MODE:-} +GUIDANCE_SCALE=${GUIDANCE_SCALE:-} +TEMPERATURE=${TEMPERATURE:-} +TIMESTEPS=${TIMESTEPS:-} +SEQ_LEN=${SEQ_LEN:-} +NOISE_SCHEDULE=${NOISE_SCHEDULE:-} +NOISE_TYPE=${NOISE_TYPE:-} +BATCH_SIZE=${BATCH_SIZE:-} +OUTPUT_DIR=${OUTPUT_DIR:-} +WER_ASR_MODEL=${WER_ASR_MODEL:-} +WER_LANGUAGE=${WER_LANGUAGE:-} +WER_MAX_SAMPLES=${WER_MAX_SAMPLES:-} +TEXT_NORM=${TEXT_NORM:-} +BLOCK_LENGTH=${BLOCK_LENGTH:-} +MAX_NEW_TOKENS=${MAX_NEW_TOKENS:-} +AUDIO_CODEBOOK_SIZE=${AUDIO_CODEBOOK_SIZE:-} + +# Aliases for convenience +CFG=${CFG:-} +TIMESTEP=${TIMESTEP:-} +WER=${WER:-} +ASR=${ASR:-} +LANG=${LANG:-} +WER_SAMPLES=${WER_SAMPLES:-} + +# Dataset params +SUBSET=${SUBSET:-} +SPLIT=${SPLIT:-} +LIMIT=${LIMIT:-} + +REST_ARGS=() +for arg in "$@"; do + case "$arg" in + *=*) eval "$arg" ;; + *) REST_ARGS+=("$arg") ;; + esac +done + +# Apply alias variables if provided +[[ -n "$CFG" ]] && TRAIN_CONFIG="$CFG" +[[ -n "$TIMESTEP" ]] && TIMESTEPS="$TIMESTEP" + +# Enable WER with defaults when WER toggle is set +if [[ -n "$WER" && "$WER" != "0" ]]; then + : "${WER_ASR_MODEL:=${ASR:-openai/whisper-large-v3}}" + : "${WER_LANGUAGE:=${LANG:-english}}" + : "${WER_MAX_SAMPLES:=${WER_SAMPLES:-64}}" +fi + +ARGS=( + --train_config "$TRAIN_CONFIG" + --ckpt_root "$CKPT_ROOT" +) + +if [[ -n "$INFER_CONFIG" ]]; then + ARGS+=(--infer_config "$INFER_CONFIG") +fi + +if [[ -n "$CHECKPOINTS" ]]; then + IFS=',' read -r -a CK_ARR <<< "$CHECKPOINTS" + for c in "${CK_ARR[@]}"; do + ARGS+=(--checkpoint "$c") + done +fi + +[[ -n "$MODE" ]] && ARGS+=(--mode "$MODE") +[[ -n "$GUIDANCE_SCALE" ]] && ARGS+=(--guidance_scale "$GUIDANCE_SCALE") +[[ -n "$TEMPERATURE" ]] && ARGS+=(--temperature "$TEMPERATURE") +[[ -n "$TIMESTEPS" ]] && ARGS+=(--timesteps "$TIMESTEPS") +[[ -n "$SEQ_LEN" ]] && ARGS+=(--seq_len "$SEQ_LEN") +[[ -n "$NOISE_SCHEDULE" ]] && ARGS+=(--noise_schedule "$NOISE_SCHEDULE") +[[ -n "$NOISE_TYPE" ]] && ARGS+=(--noise_type "$NOISE_TYPE") +[[ -n "$BATCH_SIZE" ]] && ARGS+=(--batch_size "$BATCH_SIZE") +[[ -n "$OUTPUT_DIR" ]] && ARGS+=(--output_dir "$OUTPUT_DIR") +[[ -n "$WER_ASR_MODEL" ]] && ARGS+=(--wer_asr_model "$WER_ASR_MODEL") +[[ -n "$WER_LANGUAGE" ]] && ARGS+=(--wer_language "$WER_LANGUAGE") +[[ -n "$WER_MAX_SAMPLES" ]] && ARGS+=(--wer_max_samples "$WER_MAX_SAMPLES") +[[ -n "$TEXT_NORM" ]] && ARGS+=(--text_norm "$TEXT_NORM") +[[ -n "$BLOCK_LENGTH" ]] && ARGS+=(--block_length "$BLOCK_LENGTH") +[[ -n "$MAX_NEW_TOKENS" ]] && ARGS+=(--max_new_tokens "$MAX_NEW_TOKENS") +[[ -n "$AUDIO_CODEBOOK_SIZE" ]] && ARGS+=(--audio_codebook_size "$AUDIO_CODEBOOK_SIZE") + +[[ -n "$SUBSET" ]] && ARGS+=(--subset "$SUBSET") +[[ -n "$SPLIT" ]] && ARGS+=(--split "$SPLIT") +[[ -n "$LIMIT" ]] && ARGS+=(--limit "$LIMIT") + +python -u MMaDA/inference/t2s_infer.py "${ARGS[@]}" "${REST_ARGS[@]}" + +# Example: +# bash MMaDA/inference/run_t2s.sh MODE=free GUIDANCE_SCALE=1.5 TIMESTEPS=60 SEQ_LEN=786 BATCH_SIZE=1 OUTPUT_DIR=inference/outputs/t2s_cli TRAIN_CONFIG=MMaDA/configs/omada_pretraining_stage1-3.yaml CKPT_ROOT=ckpts/omada/omada-training-stage1_3rd/checkpoint-315000/unwrapped_model + + bash MMaDA/inference/run_t2s.sh \ + WER=1 ASR=openai/whisper-large-v3 LANG=english TEXT_NORM=basic WER_SAMPLES=256 \ + TIMESTEPS=256 BLOCKSIZE=256 MODE=mmu GUIDANCE_SCALE=3.0 SEQ_LEN=512 BATCH_SIZE=4 \ + OUTPUT_DIR=inference/outputs/t2s_cli \ + TRAIN_CONFIG=MMaDA/configs/omada_pretraining_stage1-3.yaml \ + CKPT_ROOT=ckpts/omada/omada-training-stage1_7th/checkpoint-315000/unwrapped_model diff --git a/MMaDA/inference/run_t2s_grid_timesteps.sh b/MMaDA/inference/run_t2s_grid_timesteps.sh new file mode 100644 index 0000000000000000000000000000000000000000..76277234970dbb9ba38df8760fba8700e90bf9c3 --- /dev/null +++ b/MMaDA/inference/run_t2s_grid_timesteps.sh @@ -0,0 +1,23 @@ +#!/usr/bin/env bash +set -euo pipefail + +# Grid runner: sweep TIMESTEPS over powers of two (2..256) + +# Default grid (space-separated). Override with TIMESTEPS_GRID="2 4 ..." if needed. +TIMESTEPS_GRID=${TIMESTEPS_GRID:-"2 4 8 16 32 64 128 256"} + +# Pass all KEY=VALUE and CLI flags through to the base runner. +REST_ARGS=("$@") + +for TS in $TIMESTEPS_GRID; do + echo "[t2s-grid] Running TIMESTEPS=$TS" + bash MMaDA/inference/run_t2s.sh TIMESTEPS="$TS" "${REST_ARGS[@]}" +done + +# Examples: + # bash MMaDA/inference/run_t2s_grid_timesteps.sh \ + # WER=1 ASR=openai/whisper-large-v3 LANG=english TEXT_NORM=basic WER_SAMPLES=128 \ + # MODE=free GUIDANCE_SCALE=1.5 SEQ_LEN=512 BATCH_SIZE=1 \ + # OUTPUT_DIR=inference/outputs/t2s_cli \ + # TRAIN_CONFIG=MMaDA/configs/omada_pretraining_stage1-3.yaml \ + # CKPT_ROOT=ckpts/omada/omada-training-stage1_4th/checkpoint-330000/unwrapped_model diff --git a/MMaDA/inference/run_v2t.sh b/MMaDA/inference/run_v2t.sh new file mode 100755 index 0000000000000000000000000000000000000000..1c827702076ee77596c7df09fcd3be741bdc8005 --- /dev/null +++ b/MMaDA/inference/run_v2t.sh @@ -0,0 +1,59 @@ +#!/usr/bin/env bash +set -euo pipefail + +# Editable defaults (override via KEY=VALUE or extra CLI flags) +TRAIN_CONFIG=${TRAIN_CONFIG:-MMaDA/configs/omada_pretraining_stage1-3.yaml} +CKPT_ROOT=${CKPT_ROOT:-ckpts/omada/omada-training-stage1_3rd} +INFER_CONFIG=${INFER_CONFIG:-} +CHECKPOINTS=${CHECKPOINTS:-} + +# Generation params +STEPS=${STEPS:-} +BLOCK_LENGTH=${BLOCK_LENGTH:-} +MAX_NEW_TOKENS=${MAX_NEW_TOKENS:-} + +# Dataset params +VIDEO_DIR=${VIDEO_DIR:-} +QUESTIONS=${QUESTIONS:-} # comma-separated list, e.g. "Q1,Q2" + +REST_ARGS=() +for arg in "$@"; do + case "$arg" in + *=*) + key=${arg%%=*} + val=${arg#*=} + printf -v "$key" "%s" "$val" + ;; + *) REST_ARGS+=("$arg") ;; + esac +done + +ARGS=( + --train_config "$TRAIN_CONFIG" + --ckpt_root "$CKPT_ROOT" +) + +if [[ -n "$INFER_CONFIG" ]]; then + ARGS+=(--infer_config "$INFER_CONFIG") +fi + +[[ -n "$STEPS" ]] && ARGS+=(--steps "$STEPS") +[[ -n "$BLOCK_LENGTH" ]] && ARGS+=(--block_length "$BLOCK_LENGTH") +[[ -n "$MAX_NEW_TOKENS" ]] && ARGS+=(--max_new_tokens "$MAX_NEW_TOKENS") +[[ -n "$VIDEO_DIR" ]] && ARGS+=(--video_dir "$VIDEO_DIR") + +if [[ -n "$QUESTIONS" ]]; then + IFS=',' read -r -a Q_ARR <<< "$QUESTIONS" + for q in "${Q_ARR[@]}"; do + ARGS+=(--question "$q") + done +fi + +if [[ -n "$CHECKPOINTS" ]]; then + IFS=',' read -r -a CK_ARR <<< "$CHECKPOINTS" + for c in "${CK_ARR[@]}"; do + ARGS+=(--checkpoint "$c") + done +fi + +python -u MMaDA/inference/v2t_infer.py "${ARGS[@]}" "${REST_ARGS[@]}" diff --git a/MMaDA/inference/run_v2t_315k.sh b/MMaDA/inference/run_v2t_315k.sh new file mode 100644 index 0000000000000000000000000000000000000000..4fa6a4b3e63f6baf5c9cae71232373b57ca37a71 --- /dev/null +++ b/MMaDA/inference/run_v2t_315k.sh @@ -0,0 +1,28 @@ +#!/usr/bin/env bash +set -euo pipefail + +# Wrapper to run V2T inference at checkpoint-315000 + +TRAIN_CONFIG=${TRAIN_CONFIG:-MMaDA/configs/omada_pretraining_stage1-3.yaml} +CHECKPOINTS=${CHECKPOINTS:-ckpts/omada/omada-training-stage1_3rd/checkpoint-315000/unwrapped_model} + +# Optional overrides +VIDEO_DIR=${VIDEO_DIR:-/home/work/AIDAS/video/demo} +STEPS=${STEPS:-256} +BLOCK_LENGTH=${BLOCK_LENGTH:-128} +MAX_NEW_TOKENS=${MAX_NEW_TOKENS:-256} +QUESTIONS=${QUESTIONS:-"Please provide a detailed description of the video."} + +exec bash MMaDA/inference/run_v2t.sh \ + TRAIN_CONFIG="$TRAIN_CONFIG" \ + CHECKPOINTS="$CHECKPOINTS" \ + VIDEO_DIR="$VIDEO_DIR" \ + STEPS="$STEPS" \ + BLOCK_LENGTH="$BLOCK_LENGTH" \ + MAX_NEW_TOKENS="$MAX_NEW_TOKENS" \ + --question "$QUESTIONS" \ + "$@" + +# Usage examples: +# bash MMaDA/inference/run_v2t_315k.sh +# bash MMaDA/inference/run_v2t_315k.sh VIDEO_DIR=/path/to/videos QUESTIONS="Describe the video.,What happens next?" diff --git a/MMaDA/inference/s2t_infer.py b/MMaDA/inference/s2t_infer.py new file mode 100644 index 0000000000000000000000000000000000000000..56f0ae1bba57360c1f240a6e05b6257e55eccca3 --- /dev/null +++ b/MMaDA/inference/s2t_infer.py @@ -0,0 +1,350 @@ +import os +import argparse +import sys +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from functools import partial +from typing import Callable, List +import re + +import torch +from torch.utils.data import Dataset, DataLoader +from datasets import load_dataset +import wandb + +from omegaconf import OmegaConf + +from training.data import S2T_INSTRUCTION +from inference.common import ( + load_train_config, + get_vq_model_audio, + build_uni_prompting, + load_omada_from_checkpoint, + list_checkpoints, + grid_dict, + init_wandb, + safe_log_table, +) + +_ANGLE_TOKEN_RE = re.compile(r"<[^>]+>") +_EXCLAMATIONPOINT_RE = re.compile(r"EXCLAMATIONPOINT", flags=re.IGNORECASE) +_PUNCT_RE = re.compile(r"[^\w\s']") + + +def _strip_custom_markers(text: str) -> str: + had_exclamationpoint = bool(_EXCLAMATIONPOINT_RE.search(text)) + text = _ANGLE_TOKEN_RE.sub(" ", text) + if had_exclamationpoint: + text = _EXCLAMATIONPOINT_RE.sub(" ", text) + if had_exclamationpoint: + text = text.replace(".", "") + text = _PUNCT_RE.sub(" ", text) + text = re.sub(r"\s+", " ", text).strip() + return text + + +def _basic_normalize(text: str) -> str: + text = _strip_custom_markers(text) + text = text.lower() + text = re.sub(r"[^\w\s']", "", text) + text = re.sub(r"\s+", " ", text).strip() + return text + + +def build_normalize_fn(mode: str) -> Callable[[str], str]: + mode = (mode or "basic").strip().lower() + if mode in {"off", "none", "no"}: + return lambda s: s + if mode in {"english", "whisper", "whisper_en"}: + try: + from normalizer.normalizer import EnglishTextNormalizer + + n = EnglishTextNormalizer() + + def _fn(s: str) -> str: + return re.sub(r"\s+", " ", n(s)).strip() + + return _fn + except Exception: + # Fallback to basic if normalizer package import fails + return _basic_normalize + # default basic + return _basic_normalize + + +def calculate_wer(predictions: List[str], references: List[str], normalize: Callable[[str], str] = _basic_normalize): + import editdistance + # Normalize texts before WER + predictions = [normalize(p) for p in predictions] + references = [normalize(r) for r in references] + total_errors = 0 + total_words = 0 + for pred, ref in zip(predictions, references): + pred_words = pred.split() + ref_words = ref.split() + total_errors += editdistance.eval(pred_words, ref_words) + total_words += len(ref_words) + wer = total_errors / total_words if total_words > 0 else 0.0 + return wer, total_errors, total_words + + +class S2TEvalDataset(Dataset): + def __init__(self, hf_dataset, root_path: str): + self.hf_dataset = hf_dataset + self.root_path = root_path + + def __len__(self): + return len(self.hf_dataset) + + def __getitem__(self, idx): + ex = self.hf_dataset[idx] + sample_id = ex["id"] + speaker_id, chapter_id, _ = sample_id.split("-") + audio_path = os.path.join(self.root_path, speaker_id, chapter_id, f"{sample_id}.flac") + return {"audio_path": audio_path, "gt_text": ex["text"], "sample_id": sample_id} + + +def s2t_eval_collate_fn(batch, vq_model_audio, tokenizer, uni_prompting, cfg): + import random + audio_tokens_batch = [] + offset = len(uni_prompting.text_tokenizer) + cfg.model.omada.codebook_size + for item in batch: + path = item['audio_path'] + tokens = vq_model_audio.encode(path) + tokens_with_offset = tokens + offset + audio_tokens_batch.append(tokens_with_offset) + + sptids = uni_prompting.sptids_dict + device = audio_tokens_batch[0].device + batched_input_ids = [] + + for audio_tokens in audio_tokens_batch: + task_tensor = sptids['<|s2t|>'].to(device).unsqueeze(0) + soa_tensor = sptids['<|soa|>'].to(device).unsqueeze(0) + eoa_tensor = sptids['<|eoa|>'].to(device).unsqueeze(0) + audio_block = torch.cat([task_tensor, soa_tensor, audio_tokens, eoa_tensor], dim=1) + + prompt_text = random.choice(S2T_INSTRUCTION) + full_prompt_text = f'<|start_header_id|>user<|end_header_id|>\n{prompt_text}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n' + prompt_tensor = tokenizer(full_prompt_text, return_tensors="pt").input_ids.to(device) + + final_seq = torch.cat([audio_block, prompt_tensor], dim=1) + batched_input_ids.append(final_seq.squeeze(0)) + + max_len = max(seq.size(0) for seq in batched_input_ids) + pad_token_id = 126093 + + final_batch_input_ids = torch.full( + (len(batched_input_ids), max_len), + pad_token_id, + dtype=torch.long, + device=device, + ) + for i, seq in enumerate(batched_input_ids): + final_batch_input_ids[i, -len(seq):] = seq + + return { + "input_ids": final_batch_input_ids, + "gt_texts": [item['gt_text'] for item in batch], + "sample_ids": [item['sample_id'] for item in batch], + } + + +def run_once(ckpt_path: str, hparams: dict, train_cfg, device): + # Models and prompting + uni_prompting, tokenizer = build_uni_prompting(train_cfg) + vq_audio = get_vq_model_audio(train_cfg, device) + model = load_omada_from_checkpoint(ckpt_path, device) + + # Dataset + dcfg = hparams.get("dataset", {}) + subset = dcfg.get("subset", "clean") + split = dcfg.get("split", "test") + limit = int(dcfg.get("limit", 128)) + root_path = dcfg.get("root_path", "/home/work/AIDAS/data/audio/LibriSpeech/test-clean") + ds_raw = load_dataset("librispeech_asr", subset, split=split) + if limit > 0: + ds_raw = ds_raw.select(range(min(limit, len(ds_raw)))) + ds = S2TEvalDataset(ds_raw, root_path=root_path) + + collate = partial( + s2t_eval_collate_fn, + vq_model_audio=vq_audio, + tokenizer=uni_prompting.text_tokenizer, + uni_prompting=uni_prompting, + cfg=train_cfg, + ) + batch_size = int(hparams.get("batch_size", train_cfg.training.batch_size_s2t)) + loader = DataLoader(ds, batch_size=batch_size, shuffle=False, collate_fn=collate) + + # Generation hparams + steps = int(hparams.get("steps", 128)) + block_length = int(hparams.get("block_length", 64)) + max_new_tokens = int(hparams.get("max_new_tokens", 256)) + remasking = hparams.get("remasking", "low_confidence") + + # W&B + init_wandb(hparams.get("_infer_cfg", {}), "s2t", ckpt_path, { + "steps": steps, + "block_length": block_length, + "max_new_tokens": max_new_tokens, + "remasking": remasking, + "batch_size": batch_size, + }) + + preds, refs, rows = [], [], [] + norm_mode = str(hparams.get("text_norm", "basic")) + normalize_fn = build_normalize_fn(norm_mode) + for batch in loader: + input_ids = batch["input_ids"].to(device) + gt_texts = batch["gt_texts"] + sample_ids = batch["sample_ids"] + with torch.no_grad(): + output_ids = model.mmu_generate( + input_ids, + max_new_tokens=max_new_tokens, + steps=steps, + block_length=block_length, + remasking=remasking, + ) + decoded = uni_prompting.text_tokenizer.batch_decode( + output_ids[:, input_ids.shape[1]:], skip_special_tokens=True + ) + # print(decoded) + clean_gts = [_strip_custom_markers(gt) for gt in gt_texts] + clean_preds = [_strip_custom_markers(pred) for pred in decoded] + print(clean_preds) + for sid, clean_gt, clean_pred in zip(sample_ids, clean_gts, clean_preds): + refs.append(clean_gt) + preds.append(clean_pred) + rows.append([sid, clean_gt, clean_pred]) + + wer, errors, words = calculate_wer(preds, refs, normalize=normalize_fn) + wandb.log({ + "metrics/s2t_wer": wer, + "metrics/s2t_word_errors": errors, + "metrics/s2t_total_words": words, + }) + safe_log_table("samples/s2t", ["ID", "GT", "PRED"], rows[:64]) + wandb.finish() + + +def main(): + parser = argparse.ArgumentParser(description="S2T Inference with CLI overrides or config grids") + parser.add_argument("--train_config", required=True, help="Path to training YAML used to build tokenizers and VQ models") + parser.add_argument("--ckpt_root", required=True, help="Experiment output dir or a specific checkpoint path") + parser.add_argument("--infer_config", required=False, help="Optional YAML for W&B and grids") + parser.add_argument("--checkpoint", action="append", help="Repeatable: explicit checkpoint path(s). Can be '.../unwrapped_model', '.../checkpoint-XXXX', or experiment dir") + + # Generation overrides + parser.add_argument("--steps", type=int) + parser.add_argument("--block_length", type=int) + parser.add_argument("--max_new_tokens", type=int) + parser.add_argument("--remasking") + parser.add_argument("--batch_size", type=int) + parser.add_argument("--text_norm", choices=["off", "basic", "english", "whisper", "whisper_en"], help="Text normalization for WER") + + # Dataset overrides + parser.add_argument("--subset") + parser.add_argument("--split") + parser.add_argument("--root_path") + parser.add_argument("--limit", type=int) + + args = parser.parse_args() + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + train_cfg = load_train_config(args.train_config) + + infer_cfg = {} + if args.infer_config: + infer_cfg = OmegaConf.to_container(OmegaConf.load(args.infer_config), resolve=True) + + # Checkpoints + # Build checkpoint list: --checkpoint > infer_config.checkpoints > --ckpt_root + if args.checkpoint: + ckpt_list = [] + for p in args.checkpoint: + ckpt_list.extend(list_checkpoints(p)) + else: + ckpts = infer_cfg.get("checkpoints") if infer_cfg else None + if ckpts: + ckpt_list = [] + for p in ckpts: + ckpt_list.extend(list_checkpoints(p)) + else: + ckpt_list = list_checkpoints(args.ckpt_root) + if not ckpt_list: + raise FileNotFoundError(f"No checkpoints found under {args.ckpt_root} or in infer config.") + + override_present = any([ + args.steps is not None, args.block_length is not None, args.max_new_tokens is not None, + args.remasking is not None, args.batch_size is not None, + args.text_norm is not None, + args.subset is not None, args.split is not None, args.root_path is not None, args.limit is not None, + ]) + + if override_present or not infer_cfg: + single = { + "steps": args.steps if args.steps is not None else 128, + "block_length": args.block_length if args.block_length is not None else 64, + "max_new_tokens": args.max_new_tokens if args.max_new_tokens is not None else 256, + "remasking": args.remasking if args.remasking is not None else "low_confidence", + "batch_size": args.batch_size if args.batch_size is not None else int(train_cfg.training.batch_size_s2t), + } + if args.text_norm is not None: + single["text_norm"] = args.text_norm + dcfg = { + "subset": args.subset or "clean", + "split": args.split or "test", + "root_path": args.root_path or "/home/work/AIDAS/data/audio/LibriSpeech/test-clean", + "limit": args.limit if args.limit is not None else 128, + } + single["dataset"] = dcfg + single["_infer_cfg"] = infer_cfg + combos = [single] + else: + gen_grid = infer_cfg.get("generation", { + "steps": [128], + "block_length": [64], + "max_new_tokens": [256], + "remasking": ["low_confidence"], + "batch_size": [int(train_cfg.training.batch_size_s2t)], + }) + combos = grid_dict(gen_grid) + dcfg = infer_cfg.get("dataset", { + "subset": "clean", + "split": "test", + "root_path": "/home/work/AIDAS/data/audio/LibriSpeech/test-clean", + "limit": 128, + }) + # Apply overrides if provided + if args.subset is not None: + dcfg["subset"] = args.subset + if args.split is not None: + dcfg["split"] = args.split + if args.root_path is not None: + dcfg["root_path"] = args.root_path + if args.limit is not None: + dcfg["limit"] = args.limit + for c in combos: + if args.steps is not None: + c["steps"] = args.steps + if args.block_length is not None: + c["block_length"] = args.block_length + if args.max_new_tokens is not None: + c["max_new_tokens"] = args.max_new_tokens + if args.remasking is not None: + c["remasking"] = args.remasking + if args.batch_size is not None: + c["batch_size"] = args.batch_size + if args.text_norm is not None: + c["text_norm"] = args.text_norm + c["dataset"] = dcfg + c["_infer_cfg"] = infer_cfg + + for ckpt in ckpt_list: + for hp in combos: + run_once(ckpt, hp, train_cfg, device) + + +if __name__ == "__main__": + main() diff --git a/MMaDA/inference/t2s_infer.py b/MMaDA/inference/t2s_infer.py new file mode 100644 index 0000000000000000000000000000000000000000..f75fdf7b85d58d77cc55a7f0cc213417cc775ad9 --- /dev/null +++ b/MMaDA/inference/t2s_infer.py @@ -0,0 +1,476 @@ +import os +import argparse +import sys +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from typing import Callable, List +import re + +import torch +from torch.utils.data import Dataset, DataLoader +from datasets import load_dataset +import wandb + +from omegaconf import OmegaConf +from transformers import pipeline + +from training.data import T2S_INSTRUCTION +from inference.common import ( + load_train_config, + get_vq_model_audio, + build_uni_prompting, + load_omada_from_checkpoint, + list_checkpoints, + grid_dict, + init_wandb, + safe_log_table, +) +from models import get_mask_schedule + + +_ANGLE_TOKEN_RE = re.compile(r"<[^>]+>") +_EXCLAMATIONPOINT_RE = re.compile(r"exclamationpoint", flags=re.IGNORECASE) +_PUNCT_RE = re.compile(r"[^\w\s']") + + +def _strip_custom_markers(text: str) -> str: + had_exclamationpoint = bool(_EXCLAMATIONPOINT_RE.search(text)) + text = _ANGLE_TOKEN_RE.sub(" ", text) + if had_exclamationpoint: + text = _EXCLAMATIONPOINT_RE.sub(" ", text) + if had_exclamationpoint: + text = text.replace(".", "") + text = _PUNCT_RE.sub(" ", text) + text = re.sub(r"\s+", " ", text).strip() + return text + + +class T2SEvalDataset(Dataset): + def __init__(self, hf_dataset): + self.hf_dataset = hf_dataset + def __len__(self): + return len(self.hf_dataset) + def __getitem__(self, idx): + ex = self.hf_dataset[idx] + return {"gt_text": ex["text"], "sample_id": ex["id"]} + + +def ensure_dir(path: str): + os.makedirs(path, exist_ok=True) + + +def _basic_normalize(text: str) -> str: + text = _strip_custom_markers(text) + text = text.lower() + return text + + +def build_normalize_fn(mode: str) -> Callable[[str], str]: + mode = (mode or "basic").strip().lower() + if mode in {"off", "none", "no"}: + return lambda s: s + if mode in {"english", "whisper", "whisper_en"}: + try: + from normalizer.normalizer import EnglishTextNormalizer + + n = EnglishTextNormalizer() + + def _fn(s: str) -> str: + return re.sub(r"\s+", " ", n(s)).strip() + + return _fn + except Exception: + return _basic_normalize + return _basic_normalize + + +def calculate_wer(predictions: List[str], references: List[str], normalize: Callable[[str], str] = _basic_normalize): + import editdistance + # Normalize texts before WER + predictions = [normalize(p) for p in predictions] + references = [normalize(r) for r in references] + total_errors = 0 + total_words = 0 + for pred, ref in zip(predictions, references): + pw = pred.split() + rw = ref.split() + total_errors += editdistance.eval(pw, rw) + total_words += len(rw) + wer = total_errors / total_words if total_words > 0 else 0.0 + return wer, total_errors, total_words + + +def run_once(ckpt_path: str, hparams: dict, train_cfg, device): + uni_prompting, tokenizer = build_uni_prompting(train_cfg) + vq_audio = get_vq_model_audio(train_cfg, device) + model = load_omada_from_checkpoint(ckpt_path, device) + + # Dataset + dcfg = hparams.get("dataset", {}) + subset = dcfg.get("subset", "clean") + split = dcfg.get("split", "test") + limit = int(dcfg.get("limit", 32)) + ds_raw = load_dataset("librispeech_asr", subset, split=split) + if limit > 0: + ds_raw = ds_raw.select(range(min(limit, len(ds_raw)))) + ds = T2SEvalDataset(ds_raw) + batch_size = int(hparams.get("batch_size", train_cfg.training.batch_size_t2s)) + loader = DataLoader(ds, batch_size=batch_size, shuffle=False) + + # Generation params + mode = str(hparams.get("mode", "fixed")).lower() # 'fixed', 'free', or 'mmu' + guidance_scale = float(hparams.get("guidance_scale", train_cfg.training.guidance_scale)) + temperature = float(hparams.get("temperature", 1.0)) + timesteps = int(hparams.get("timesteps", 24 if mode != "mmu" else 256)) + default_seq = 254 if mode == "fixed" else (511 if mode == "mmu" else 255) + seq_len = int(hparams.get("seq_len", default_seq)) + block_length = int(hparams.get("block_length", 128)) + max_new_tokens = int(hparams.get("max_new_tokens", seq_len)) if seq_len > 0 else int(hparams.get("max_new_tokens", 512)) + audio_codebook_size = int(hparams.get("audio_codebook_size", 4096)) + noise_schedule = hparams.get("noise_schedule", train_cfg.training.get("mask_schedule", "cosine")) + # Convert string name to callable schedule function expected by model + noise_schedule_fn = get_mask_schedule(noise_schedule) if isinstance(noise_schedule, str) else noise_schedule + noise_type = hparams.get("noise_type", "mask") + + out_root = hparams.get("output_dir", os.path.join("outputs", "t2s")) + ensure_dir(out_root) + + # W&B + init_wandb(hparams.get("_infer_cfg", {}), "t2s", ckpt_path, { + "mode": mode, + "guidance_scale": guidance_scale, + "temperature": temperature, + "timesteps": timesteps, + "seq_len": seq_len, + "batch_size": batch_size, + }) + + mask_token_id = model.config.mask_token_id + rows = [] + + for batch in loader: + gt_texts: List[str] = batch["gt_text"] + clean_gt_texts = [_strip_custom_markers(text) for text in gt_texts] + sample_ids: List[str] = batch["sample_id"] + + # Build chat prompts + prompts = [ + f"<|start_header_id|>user<|end_header_id|>\n{T2S_INSTRUCTION[0]}\n{text}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n" + for text in clean_gt_texts + ] + bsz = len(prompts) + audio_tokens = torch.ones((bsz, seq_len), dtype=torch.long, device=device) * mask_token_id + if mode == "fixed": + input_ids, attention_mask = uni_prompting((prompts, audio_tokens), 't2s_fixed_gen') + else: + input_ids, attention_mask = uni_prompting((prompts, audio_tokens), 't2s_gen') + + if guidance_scale and guidance_scale > 0 and mode != "mmu": + if mode == "fixed": + uncond_input_ids, uncond_attention_mask = uni_prompting(([''] * bsz, audio_tokens), 't2s_fixed_gen') + else: + uncond_input_ids, uncond_attention_mask = uni_prompting(([''] * bsz, audio_tokens), 't2s_gen') + else: + uncond_input_ids, uncond_attention_mask = None, None + + with torch.no_grad(): + if mode == "fixed": + outputs = model.t2s_fixed_generate( + input_ids=input_ids.to(device), + uncond_input_ids=None if uncond_input_ids is None else uncond_input_ids.to(device), + attention_mask=attention_mask.to(device), + uncond_attention_mask=None if uncond_attention_mask is None else uncond_attention_mask.to(device), + guidance_scale=guidance_scale, + temperature=temperature, + timesteps=timesteps, + noise_schedule=noise_schedule_fn, + noise_type=noise_type, + seq_len=seq_len, + uni_prompting=uni_prompting, + config=train_cfg, + ) + elif mode == "mmu": + outputs = model.t2s_generate_mmu_like( + input_ids=input_ids.to(device), + max_new_tokens=max_new_tokens, + steps=timesteps, + block_length=block_length, + temperature=temperature, + cfg_scale=guidance_scale, + mask_token_id=mask_token_id, + attention_mask=attention_mask.to(device), + uni_prompting=uni_prompting, + codebook_size=train_cfg.model.omada.codebook_size, + audio_codebook_size=audio_codebook_size, + ) + else: + outputs = model.t2s_generate( + input_ids=input_ids.to(device), + uncond_input_ids=None if uncond_input_ids is None else uncond_input_ids.to(device), + attention_mask=attention_mask.to(device), + uncond_attention_mask=None if uncond_attention_mask is None else uncond_attention_mask.to(device), + guidance_scale=guidance_scale, + temperature=temperature, + timesteps=timesteps, + noise_schedule=noise_schedule_fn, + noise_type=noise_type, + seq_len=seq_len, + uni_prompting=uni_prompting, + config=train_cfg, + ) + + # Decode each sample + for i in range(bsz): + if mode == "mmu": + gen_tokens = outputs[i] + if isinstance(gen_tokens, torch.Tensor): + rel_ids = gen_tokens.detach().cpu().tolist() + else: + rel_ids = list(gen_tokens) + else: + rel_ids = outputs[i].tolist() + if not rel_ids: + continue + unit_str = " ".join(map(str, rel_ids)) + speech_unit = "".join([f"<|speech_{u}|>" for u in unit_str.split(" ")]) + wav_name = f"{os.path.basename(os.path.dirname(ckpt_path))}_{sample_ids[i]}_{mode}.wav" + wav_path = os.path.join(out_root, wav_name) + _ = vq_audio.decode(speech_unit, condition='gender-female_emotion-neutral_speed-normal_pitch-normal', output_wav_file=wav_path) + rows.append([sample_ids[i], clean_gt_texts[i], wav_path]) + + # Log audio samples + aud_rows = [] + for sid, gt, wav in rows[:64]: + aud_rows.append([sid, gt, wandb.Audio(wav, caption=gt)]) + safe_log_table("samples/t2s", ["ID", "GT", "Audio"], aud_rows) + + # Optional WER evaluation via Whisper (or any ASR pipeline) + asr_model = hparams.get("wer_asr_model") + if asr_model: + try: + lang_in = hparams.get("wer_language", "english") + # Normalize language to avoid locale strings like C.UTF-8 + def _norm_lang(x: str) -> str: + if not isinstance(x, str) or not x: + return "english" + x = x.strip().lower() + if "utf" in x or x.startswith("c.") or x == "c": + return "english" + aliases = { + "en": "english", "eng": "english", "english": "english", + "ko": "korean", "kor": "korean", "korean": "korean", + "zh": "chinese", "cmn": "chinese", "chinese": "chinese", + "ja": "japanese", "jpn": "japanese", "japanese": "japanese", + } + return aliases.get(x, "english") + + lang = _norm_lang(lang_in) + max_samples = int(hparams.get("wer_max_samples", 1024)) + use_cuda = torch.cuda.is_available() + asr_pipe = pipeline("automatic-speech-recognition", model=asr_model, device=0 if use_cuda else -1) + + preds, refs = [], [] + norm_mode = str(hparams.get("text_norm", "basic")) + normalize_fn = build_normalize_fn(norm_mode) + trans_rows = [] + for i, (sid, gt, wav) in enumerate(rows): + if i >= max_samples: + break + try: + out = asr_pipe(wav, generate_kwargs={"language": lang, "task": "transcribe"}) + text = out.get("text", "") + except Exception: + text = "" + base_pred = _strip_custom_markers(text) + base_ref = _strip_custom_markers(gt) + preds.append(base_pred) + refs.append(base_ref) + if i < 32: + trans_rows.append([sid, base_ref, base_pred, wandb.Audio(wav, caption=base_pred)]) + + # Compute WER using normalized text + wer, errors, words = calculate_wer(preds, refs, normalize=normalize_fn) + wandb.log({ + "metrics/t2s_wer": wer, + "metrics/t2s_word_errors": errors, + "metrics/t2s_total_words": words, + }) + safe_log_table("samples/t2s_transcriptions", ["ID", "GT", "ASR", "Audio"], trans_rows) + except Exception as e: + wandb.log({"warn/t2s_wer_error": str(e)}) + + wandb.finish() + + +def main(): + parser = argparse.ArgumentParser(description="T2S Inference (fixed/free) with CLI overrides or config grids") + # Required basics + parser.add_argument("--train_config", required=True) + parser.add_argument("--ckpt_root", required=True, help="Experiment output dir or specific checkpoint path") + parser.add_argument("--infer_config", required=False, help="Optional YAML with wandb and/or grid configs") + parser.add_argument("--checkpoint", action="append", help="Repeatable: explicit checkpoint path(s). Can be '.../unwrapped_model', '.../checkpoint-XXXX', or experiment dir") + + # Optional generation overrides (single run when provided) + parser.add_argument("--mode", choices=["fixed", "free", "mmu"], help="T2S mode: fixed, free, or mmu") + parser.add_argument("--guidance_scale", type=float) + parser.add_argument("--temperature", type=float) + parser.add_argument("--timesteps", type=int) + parser.add_argument("--seq_len", type=int) + parser.add_argument("--block_length", type=int) + parser.add_argument("--max_new_tokens", type=int) + parser.add_argument("--noise_schedule") + parser.add_argument("--noise_type") + parser.add_argument("--batch_size", type=int) + parser.add_argument("--output_dir") + parser.add_argument("--text_norm", choices=["off", "basic", "english", "whisper", "whisper_en"], help="Text normalization for WER") + + # Optional dataset overrides + parser.add_argument("--subset") + parser.add_argument("--split") + parser.add_argument("--limit", type=int) + + # Optional WER logging via ASR + parser.add_argument("--wer_asr_model", help="HF model id for ASR, e.g., openai/whisper-large-v3") + parser.add_argument("--wer_language", help="Language hint for ASR generation") + parser.add_argument("--wer_max_samples", type=int, help="Max number of samples for WER computation") + parser.add_argument("--audio_codebook_size", type=int, help="Override audio codebook size for MMU mode") + + args = parser.parse_args() + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + train_cfg = load_train_config(args.train_config) + + infer_cfg = {} + if args.infer_config: + infer_cfg = OmegaConf.to_container(OmegaConf.load(args.infer_config), resolve=True) + + # Checkpoints + # Build checkpoint list by priority: explicit --checkpoint > infer_config.checkpoints > --ckpt_root + if args.checkpoint: + ckpt_list = [] + for p in args.checkpoint: + ckpt_list.extend(list_checkpoints(p)) + else: + ckpts = infer_cfg.get("checkpoints") if infer_cfg else None + if ckpts: + ckpt_list = [] + for p in ckpts: + ckpt_list.extend(list_checkpoints(p)) + else: + ckpt_list = list_checkpoints(args.ckpt_root) + if not ckpt_list: + raise FileNotFoundError(f"No checkpoints found under {args.ckpt_root} or in infer config.") + + # Decide between single-run overrides or grid from config + override_present = any([ + args.mode is not None, args.guidance_scale is not None, args.temperature is not None, + args.timesteps is not None, args.seq_len is not None, args.noise_schedule is not None, + args.noise_type is not None, args.batch_size is not None, args.output_dir is not None, + args.block_length is not None, args.max_new_tokens is not None, + args.text_norm is not None, + args.subset is not None, args.split is not None, args.limit is not None, + ]) + + if override_present or not infer_cfg: + # Build single combination from CLI overrides with fallbacks + single = { + "mode": args.mode or "fixed", + "guidance_scale": args.guidance_scale if args.guidance_scale is not None else float(train_cfg.training.guidance_scale), + "temperature": args.temperature if args.temperature is not None else 1.0, + "timesteps": args.timesteps if args.timesteps is not None else 24, + "seq_len": args.seq_len if args.seq_len is not None else 254, + "batch_size": args.batch_size if args.batch_size is not None else int(train_cfg.training.batch_size_t2s), + "output_dir": args.output_dir or os.path.join("outputs", "t2s"), + "noise_schedule": args.noise_schedule if args.noise_schedule is not None else train_cfg.training.get("mask_schedule", "cosine"), + "noise_type": args.noise_type if args.noise_type is not None else "mask", + } + if args.text_norm is not None: + single["text_norm"] = args.text_norm + if args.block_length is not None: + single["block_length"] = args.block_length + if args.max_new_tokens is not None: + single["max_new_tokens"] = args.max_new_tokens + if args.audio_codebook_size is not None: + single["audio_codebook_size"] = args.audio_codebook_size + # WER options + if args.wer_asr_model is not None: + single["wer_asr_model"] = args.wer_asr_model + if args.wer_language is not None: + single["wer_language"] = args.wer_language + if args.wer_max_samples is not None: + single["wer_max_samples"] = args.wer_max_samples + dcfg = { + "subset": args.subset or "clean", + "split": args.split or "test", + "limit": args.limit if args.limit is not None else 32, + } + single["dataset"] = dcfg + single["_infer_cfg"] = infer_cfg + combos = [single] + else: + # Grid from config, allow CLI overrides to force values across the grid + gen_grid = infer_cfg.get("generation", { + "mode": ["fixed"], + "guidance_scale": [float(train_cfg.training.guidance_scale)], + "temperature": [1.0], + "timesteps": [24], + "seq_len": [254], + "batch_size": [int(train_cfg.training.batch_size_t2s)], + "output_dir": [os.path.join("outputs", "t2s")], + }) + combos = grid_dict(gen_grid) + dcfg = infer_cfg.get("dataset", { + "subset": "clean", + "split": "test", + "limit": 32, + }) + # Apply dataset overrides if given + if args.subset is not None: + dcfg["subset"] = args.subset + if args.split is not None: + dcfg["split"] = args.split + if args.limit is not None: + dcfg["limit"] = args.limit + # Apply generation overrides across combos if provided + for c in combos: + if args.mode is not None: + c["mode"] = args.mode + if args.guidance_scale is not None: + c["guidance_scale"] = args.guidance_scale + if args.temperature is not None: + c["temperature"] = args.temperature + if args.timesteps is not None: + c["timesteps"] = args.timesteps + if args.seq_len is not None: + c["seq_len"] = args.seq_len + if args.batch_size is not None: + c["batch_size"] = args.batch_size + if args.output_dir is not None: + c["output_dir"] = args.output_dir + if args.noise_schedule is not None: + c["noise_schedule"] = args.noise_schedule + if args.noise_type is not None: + c["noise_type"] = args.noise_type + if args.text_norm is not None: + c["text_norm"] = args.text_norm + if args.block_length is not None: + c["block_length"] = args.block_length + if args.max_new_tokens is not None: + c["max_new_tokens"] = args.max_new_tokens + if args.audio_codebook_size is not None: + c["audio_codebook_size"] = args.audio_codebook_size + if args.wer_asr_model is not None: + c["wer_asr_model"] = args.wer_asr_model + if args.wer_language is not None: + c["wer_language"] = args.wer_language + if args.wer_max_samples is not None: + c["wer_max_samples"] = args.wer_max_samples + c["dataset"] = dcfg + c["_infer_cfg"] = infer_cfg + + for ckpt in ckpt_list: + for hp in combos: + run_once(ckpt, hp, train_cfg, device) + + +if __name__ == "__main__": + main() diff --git a/MMaDA/inference/upload.py b/MMaDA/inference/upload.py new file mode 100644 index 0000000000000000000000000000000000000000..ef2fa7c295a4aca16c79c9ebc0afd2fc38c6f8bb --- /dev/null +++ b/MMaDA/inference/upload.py @@ -0,0 +1,9 @@ +from huggingface_hub import HfApi +import os + +api = HfApi(token=os.getenv("HF_TOKEN")) +api.upload_folder( + folder_path="/t1data/users/snu-lab-d/omada/ckpt/checkpoint-400000/unwrapped_model", + repo_id="jaeikkim/AIDAS-Omni-Modal-Diffusion", + repo_type="model", +) \ No newline at end of file diff --git a/MMaDA/inference/v2t_infer.py b/MMaDA/inference/v2t_infer.py new file mode 100644 index 0000000000000000000000000000000000000000..1bdb6c432937b1ef458940d33278882ad9cf0dbb --- /dev/null +++ b/MMaDA/inference/v2t_infer.py @@ -0,0 +1,201 @@ +import os +import argparse +import sys +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from typing import List + +import numpy as np +import cv2 +import torch +from PIL import Image +from torch.utils.data import DataLoader +import wandb + +from omegaconf import OmegaConf + +from training.utils import image_transform +from inference.common import ( + load_train_config, + get_vq_model_image, + build_uni_prompting, + load_omada_from_checkpoint, + list_checkpoints, + grid_dict, + init_wandb, + safe_log_table, +) + + +def sample_video_tokens(video_path: str, vq_model_image, uni_prompting, cfg, device) -> torch.Tensor: + cap = cv2.VideoCapture(video_path) + total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + if total_frames <= 0: + cap.release() + raise RuntimeError(f"No frames in {video_path}") + indices = np.linspace(0, total_frames - 1, 8, dtype=int) + frames = [] + for i in range(total_frames): + ret, frame = cap.read() + if i in indices: + if not ret: + continue + frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + pil_img = Image.fromarray(frame) + frames.append(image_transform(pil_img, resolution=cfg.dataset.preprocessing.resolution)) + cap.release() + if len(frames) < 8: + raise RuntimeError(f"Insufficient frames from {video_path}") + video_tensor = torch.stack(frames).to(device) + # offset by text tokenizer length as in training evaluation + video_tokens = vq_model_image.get_code(video_tensor) + len(uni_prompting.text_tokenizer) + video_tokens = video_tokens.view(1, -1) + return video_tokens + + +def build_input_ids(video_tokens: torch.Tensor, question: str, uni_prompting, device) -> torch.Tensor: + spt = uni_prompting.sptids_dict + prompt_text = f'<|start_header_id|>user<|end_header_id|>\n{question}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n' + prompt_tensor = uni_prompting.text_tokenizer(prompt_text, return_tensors="pt").input_ids.to(device) + input_ids = torch.cat([ + spt['<|v2t|>'].to(device).unsqueeze(0), + spt['<|soi|>'].to(device).unsqueeze(0), + video_tokens, + spt['<|eoi|>'].to(device).unsqueeze(0), + spt['<|sot|>'].to(device).unsqueeze(0), + prompt_tensor + ], dim=1).long() + return input_ids + + +def run_once(ckpt_path: str, hparams: dict, train_cfg, device): + uni_prompting, tokenizer = build_uni_prompting(train_cfg) + vq_img = get_vq_model_image(train_cfg, device) + model = load_omada_from_checkpoint(ckpt_path, device) + + video_dir = hparams.get("video_dir", "/home/work/AIDAS/video/demo") + questions = hparams.get("questions", ["Please provide a detailed description of the video."]) + steps = int(hparams.get("steps", 256)) + block_length = int(hparams.get("block_length", 128)) + max_new_tokens = int(hparams.get("max_new_tokens", 256)) + + # W&B + init_wandb(hparams.get("_infer_cfg", {}), "v2t", ckpt_path, { + "steps": steps, + "block_length": block_length, + "max_new_tokens": max_new_tokens, + }) + + files = [f for f in os.listdir(video_dir) if f.lower().endswith(".mp4")] + files.sort() + rows = [] + for fname in files: + vpath = os.path.join(video_dir, fname) + try: + vtoks = sample_video_tokens(vpath, vq_img, uni_prompting, train_cfg, device) + except Exception: + continue + for q in questions: + inp = build_input_ids(vtoks, q, uni_prompting, device) + with torch.no_grad(): + out_ids = model.mmu_generate( + inp, + max_new_tokens=max_new_tokens, + steps=steps, + block_length=block_length, + ) + text = uni_prompting.text_tokenizer.batch_decode( + out_ids[:, inp.shape[1]:], skip_special_tokens=True + )[0] + rows.append([fname, q, text]) + + safe_log_table("samples/v2t", ["Video", "Question", "Caption"], rows) + wandb.finish() + + +def main(): + parser = argparse.ArgumentParser(description="V2T Inference with CLI overrides or config grids") + parser.add_argument("--train_config", required=True) + parser.add_argument("--ckpt_root", required=True) + parser.add_argument("--infer_config", required=False) + parser.add_argument("--checkpoint", action="append", help="Repeatable: explicit checkpoint path(s). Can be '.../unwrapped_model', '.../checkpoint-XXXX', or experiment dir") + + # Generation overrides + parser.add_argument("--steps", type=int) + parser.add_argument("--block_length", type=int) + parser.add_argument("--max_new_tokens", type=int) + + # Dataset overrides + parser.add_argument("--video_dir") + parser.add_argument("--question", action="append", help="Repeatable: --question 'text' --question 'another'") + + args = parser.parse_args() + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + train_cfg = load_train_config(args.train_config) + + infer_cfg = {} + if args.infer_config: + infer_cfg = OmegaConf.to_container(OmegaConf.load(args.infer_config), resolve=True) + + if args.checkpoint: + ckpt_list = [] + for p in args.checkpoint: + ckpt_list.extend(list_checkpoints(p)) + else: + ckpts = infer_cfg.get("checkpoints") if infer_cfg else None + if ckpts: + ckpt_list = [] + for p in ckpts: + ckpt_list.extend(list_checkpoints(p)) + else: + ckpt_list = list_checkpoints(args.ckpt_root) + if not ckpt_list: + raise FileNotFoundError(f"No checkpoints found under {args.ckpt_root} or in infer config.") + + override_present = any([ + args.steps is not None, args.block_length is not None, args.max_new_tokens is not None, + args.video_dir is not None, args.question is not None, + ]) + + if override_present or not infer_cfg: + single = { + "steps": args.steps if args.steps is not None else 256, + "block_length": args.block_length if args.block_length is not None else 128, + "max_new_tokens": args.max_new_tokens if args.max_new_tokens is not None else 256, + "video_dir": args.video_dir or "/home/work/AIDAS/video/demo", + "questions": args.question or ["Please provide a detailed description of the video."], + } + single["_infer_cfg"] = infer_cfg + combos = [single] + else: + gen_grid = infer_cfg.get("generation", { + "steps": [256], + "block_length": [128], + "max_new_tokens": [256], + }) + combos = grid_dict(gen_grid) + dcfg = infer_cfg.get("dataset", { + "video_dir": "/home/work/AIDAS/video/demo", + "questions": ["Please provide a detailed description of the video."], + }) + if args.video_dir is not None: + dcfg["video_dir"] = args.video_dir + if args.question is not None: + dcfg["questions"] = args.question + for c in combos: + if args.steps is not None: + c["steps"] = args.steps + if args.block_length is not None: + c["block_length"] = args.block_length + if args.max_new_tokens is not None: + c["max_new_tokens"] = args.max_new_tokens + c.update(dcfg) + c["_infer_cfg"] = infer_cfg + + for ckpt in ckpt_list: + for hp in combos: + run_once(ckpt, hp, train_cfg, device) + + +if __name__ == "__main__": + main() diff --git a/MMaDA/inference_emova.py b/MMaDA/inference_emova.py new file mode 100644 index 0000000000000000000000000000000000000000..3bf1243c55b9157861972084fedac39cce10c797 --- /dev/null +++ b/MMaDA/inference_emova.py @@ -0,0 +1,250 @@ +import sys +import json +import torch +from models.modeling_emova_speech_tokenizer import EMOVASpeechTokenizer +import soundfile as sf +import re + +input_json_path = "/home/work/AIDAS/t2s_logs/librispeech_result.json" +output_dir = "/home/work/AIDAS/t2s_logs/decoded_wav/" + +import os +os.makedirs(output_dir, exist_ok=True) + +device = "cuda" if torch.cuda.is_available() else "cpu" +vq_model = EMOVASpeechTokenizer.from_pretrained("Emova-ollm/emova_speech_tokenizer_hf").to(device) +vq_model.eval() + +# tokens = vq_model.encode("/home/work/AIDAS/data/audio/commonvoice/cv-corpus-22.0-2025-06-20/en/clips/common_voice_en_619035.mp3") + +# numbers = tokens.tolist()[0] + +# print(numbers) + +""" +ģ•„ź¹Œ 135080 ė‚˜ģ˜Ø ė’¤ ģ‹œķ€€ģŠ¤ė“¤ 다 ėŠź²¼ģ—ˆėŠ”ė° +ė‹¤ģ‹œ ķ•˜ė‹ˆź¹Œ 또 ģ•ˆėŠź²Øģš” ;;;;;; +""" + +#### ORIGINAL +numbers = [135080] * 50 + +# 12쓈 짜리 + +# (NOT SUCH A SERIOUS PROBLEM IN NEVADA BUT IN OTHER PARTS OF THE COUNTRY WHERE +# WHERE GOVERNORS AND STATE LEGISLATORS HAVE OPTED OUT OF THE AFFORDABLE CARE ACT AND +# SO MANY OF THEIR OWN PEOPLE ARE LEFT BEHIND.") + +# 135080ģ—ģ„œ ėŠź¹€ (다 ģŒģ†Œ 토큰) +# 또 ėŠźø°ė‚˜ ķ™•ģø +numbers = [135081, 135080, 135081, 135666, 136883, 138099, 134597, 134768, 135402, + 137951, 135928, 135038, 135109, 135436, 135484, 136375, 138057, 136523, + 135883, 135042, 135038, 138177, 134603, 134599, 137728, 138203, 137115, + 138121, 135563, 135039, 135426, 136773, 136133, 136197, 135677, 134903, + 134843, 135244, 137737, 136583, 137606, 138057, 136588, 137610, 138427, + 138020, 137003, 137458, 138032, 136750, 135282, 137416, 136743, 135329, + 136875, 137832, 137401, 136387, 138248, 137990, 135985, 138462, 138364, + 138012, 137068, 135674, 135688, 134792, 135436, 138566, 137349, 136267, + 136259, 135802, 137285, 135621, 134662, 134661, 134654, 135738, 137440, + 138005, 135029, 135174, 134792, 137416, 138558, 136378, 136443, 135149, + 136133, 136068, 137076, 136500, 138413, 136125, 136253, 138345, 138132, + 136620, 135675, 135674, 135600, 135080, 135081, 135666, 135665, 137653, + 138309, 134650, 135102, 136760, 136761, 136250, 135665, 135600, 135592, + 135080, 135081, 135593, 135666, 136178, 135601, 137661, 138245, 134650, + 134591, 135102, 136761, 136250, 135666, 135602, 135666, 136308, 136639, + 135622, 134794, 135241, 137415, 136263, 137403, 138099, 136262, 136267, + 136259, 135747, 134784, 134839, 135738, 135674, 135602, 136178, 138498, + 138636, 138628, 138500, 138363, 136250, 135610, 135602, 136186, 136379, + 135040, 135535, 138285, 137271, 134676, 137950, 136312, 137095, 137413, + 135332, 137346, 137409, 136972, 135041, 135030, 137095, 138622, 135124, + 135842, 138349, 136264, 136779, 134838, 135224, 136589, 136085, 134721, + 136391, 136833, 136250, 136177, 136754, 135725, 134904, 135355, 135428, + 136838, 136880, 137951, 138351, 137985, 136892, 138346, 135870, 134614, + 134623, 134624, 135210, 135283, 136938, 138398, 136250, 135601, 135080, + 135592, 135658, 136178, 136753, 135238, 136456, 138559, 138629, 138496, + 136250, 138056, 135883, 135365, 136125, 134573, 134581, 137275, 138346, + 138055, 136908, 138248, 138120, 136073, 137026, 135205, 136702, 134714, + 134584, 134591, 135103, 136191, 136761, 137519, 135339, 135868, 135411, + 135775, 136865, 136817, 135601, 135592, 135080, 135081, 135666, 136690, + 138634, 138628, 138355, 136754, 135993, 135542, 135085, 134571, 135091, + 135667, 136379, 137611, 136262, 134892, 136028, 137083, 137827, 137108, + 136178, 135174, 134792, 137416, 137473, 138565, 138309, 134592, 135614, + 136185, 136186, 135086, 134571, 135164, 137588, 135803, 135739, 135875, + 137285, 138151, 136596, 135068, 134772, 134916, 137797, 138119, 136073, + 136186, 135166, 134661, 136248, 135737, 137095, 136325, 134700, 134659, + 134917, 136430, 138518, 137265, 138248, 138314, 136188, 136598, 137133, + 134552, 134546, 134547, 134555, 136420, 137580, 138108, 138436, 138354, + 136762, 135610, 135081, 135080] + +# numbers = [134552] * 50 +### MID POSITIONED-FUCKING 135080 REMOVED + +# numbers = [135081, 135080, 135081, 135666, 136883, 138099, 134597, 134768, 135402, +# 137951, 135928, 135038, 135109, 135436, 135484, 136375, 138057, 136523, +# 135883, 135042, 135038, 138177, 134603, 134599, 137728, 138203, 137115, +# 138121, 135563, 135039, 135426, 136773, 136133, 136197, 135677, 134903, +# 134843, 135244, 137737, 136583, 137606, 138057, 136588, 137610, 138427, +# 138020, 137003, 137458, 138032, 136750, 135282, 137416, 136743, 135329, +# 136875, 137832, 137401, 136387, 138248, 137990, 135985, 138462, 138364, +# 138012, 137068, 135674, 135688, 134792, 135436, 138566, 137349, 136267, +# 136259, 135802, 137285, 135621, 134662, 134661, 134654, 135738, 137440, +# 138005, 135029, 135174, 134792, 137416, 138558, 136378, 136443, 135149, +# 136133, 136068, 137076, 136500, 138413, 136125, 136253, 138345, 138132, +# 136620, 135675, 135674, 135600, 135081, 135666, 135665, 137653, +# 138309, 134650, 135102, 136760, 136761, 136250, 135665, 135600, 135592, +# 135081, 135593, 135666, 136178, 135601, 137661, 138245, 134650, +# 134591, 135102, 136761, 136250, 135666, 135602, 135666, 136308, 136639, +# 135622, 134794, 135241, 137415, 136263, 137403, 138099, 136262, 136267, +# 136259, 135747, 134784, 134839, 135738, 135674, 135602, 136178, 138498, +# 138636, 138628, 138500, 138363, 136250, 135610, 135602, 136186, 136379, +# 135040, 135535, 138285, 137271, 134676, 137950, 136312, 137095, 137413, +# 135332, 137346, 137409, 136972, 135041, 135030, 137095, 138622, 135124, +# 135842, 138349, 136264, 136779, 134838, 135224, 136589, 136085, 134721, +# 136391, 136833, 136250, 136177, 136754, 135725, 134904, 135355, 135428, +# 136838, 136880, 137951, 138351, 137985, 136892, 138346, 135870, 134614, +# 134623, 134624, 135210, 135283, 136938, 138398, 136250, 135601, +# 135592, 135658, 136178, 136753, 135238, 136456, 138559, 138629, 138496, +# 136250, 138056, 135883, 135365, 136125, 134573, 134581, 137275, 138346, +# 138055, 136908, 138248, 138120, 136073, 137026, 135205, 136702, 134714, +# 134584, 134591, 135103, 136191, 136761, 137519, 135339, 135868, 135411, +# 135775, 136865, 136817, 135601, 135592, 135081, 135666, 136690, +# 138634, 138628, 138355, 136754, 135993, 135542, 135085, 134571, 135091, +# 135667, 136379, 137611, 136262, 134892, 136028, 137083, 137827, 137108, +# 136178, 135174, 134792, 137416, 137473, 138565, 138309, 134592, 135614, +# 136185, 136186, 135086, 134571, 135164, 137588, 135803, 135739, 135875, +# 137285, 138151, 136596, 135068, 134772, 134916, 137797, 138119, 136073, +# 136186, 135166, 134661, 136248, 135737, 137095, 136325, 134700, 134659, +# 134917, 136430, 138518, 137265, 138248, 138314, 136188, 136598, 137133, +# 134552, 134546, 134547, 134555, 136420, 137580, 138108, 138436, 138354, +# 136762, 135610, 135081, 135080] + +# ### ALL-FUCKING 135080 / 135081 REMOVED +# numbers = [ 135666, 136883, 138099, 134597, 134768, 135402, +# 137951, 135928, 135038, 135109, 135436, 135484, 136375, 138057, 136523, +# 135883, 135042, 135038, 138177, 134603, 134599, 137728, 138203, 137115, +# 138121, 135563, 135039, 135426, 136773, 136133, 136197, 135677, 134903, +# 134843, 135244, 137737, 136583, 137606, 138057, 136588, 137610, 138427, +# 138020, 137003, 137458, 138032, 136750, 135282, 137416, 136743, 135329, +# 136875, 137832, 137401, 136387, 138248, 137990, 135985, 138462, 138364, +# 138012, 137068, 135674, 135688, 134792, 135436, 138566, 137349, 136267, +# 136259, 135802, 137285, 135621, 134662, 134661, 134654, 135738, 137440, +# 138005, 135029, 135174, 134792, 137416, 138558, 136378, 136443, 135149, +# 136133, 136068, 137076, 136500, 138413, 136125, 136253, 138345, 138132, +# 136620, 135675, 135674, 135600, 135666, 135665, 137653, +# 138309, 134650, 135102, 136760, 136761, 136250, 135665, 135600, 135592, +# 135593, 135666, 136178, 135601, 137661, 138245, 134650, +# 134591, 135102, 136761, 136250, 135666, 135602, 135666, 136308, 136639, +# 135622, 134794, 135241, 137415, 136263, 137403, 138099, 136262, 136267, +# 136259, 135747, 134784, 134839, 135738, 135674, 135602, 136178, 138498, +# 138636, 138628, 138500, 138363, 136250, 135610, 135602, 136186, 136379, +# 135040, 135535, 138285, 137271, 134676, 137950, 136312, 137095, 137413, +# 135332, 137346, 137409, 136972, 135041, 135030, 137095, 138622, 135124, +# 135842, 138349, 136264, 136779, 134838, 135224, 136589, 136085, 134721, +# 136391, 136833, 136250, 136177, 136754, 135725, 134904, 135355, 135428, +# 136838, 136880, 137951, 138351, 137985, 136892, 138346, 135870, 134614, +# 134623, 134624, 135210, 135283, 136938, 138398, 136250, 135601, +# 135592, 135658, 136178, 136753, 135238, 136456, 138559, 138629, 138496, +# 136250, 138056, 135883, 135365, 136125, 134573, 134581, 137275, 138346, +# 138055, 136908, 138248, 138120, 136073, 137026, 135205, 136702, 134714, +# 134584, 134591, 135103, 136191, 136761, 137519, 135339, 135868, 135411, +# 135775, 136865, 136817, 135601, 135592, 135666, 136690, +# 138634, 138628, 138355, 136754, 135993, 135542, 135085, 134571, 135091, +# 135667, 136379, 137611, 136262, 134892, 136028, 137083, 137827, 137108, +# 136178, 135174, 134792, 137416, 137473, 138565, 138309, 134592, 135614, +# 136185, 136186, 135086, 134571, 135164, 137588, 135803, 135739, 135875, +# 137285, 138151, 136596, 135068, 134772, 134916, 137797, 138119, 136073, +# 136186, 135166, 134661, 136248, 135737, 137095, 136325, 134700, 134659, +# 134917, 136430, 138518, 137265, 138248, 138314, 136188, 136598, 137133, +# 134552, 134546, 134547, 134555, 136420, 137580, 138108, 138436, 138354, +# 136762, 135610, ] + +### ź°•ģ œ ģ£¼ģž… 뭐지 ģ™œ ģ•ˆė©ˆģ¶”ģ§€? +# numbers = [ 135666, 136883, 138099, 134597, 134768, 135402, +# 137951, 135928, 135038, 135109, 135436, 135484, 136375, 138057, 136523, +# 135883, 135042, 135038, 138177, 134603, 134599, 137728, 138203, 137115, +# 138121, 135563, 135039, 135426, 136773, 136133, 136197, 135677, 134903, +# 134843, 135244, 137737, 136583, 137606, 138057, 136588, 137610, 138427, +# 138020, 137003, 137458, 138032, 136750, 135282, 137416, 136743, 135329, +# 136875, 137832, 137401, 136387, 138248, 137990, 135985, 138462, 138364, +# 138012, 137068, 135674, 135688, 134792, 135436, 138566, 137349, 136267, +# 136259, 135802, 137285, 135621, 134662, 134661, 134654, 135080, 135081, 135738, 137440, +# 138005, 135029, 135174, 134792, 137416, 138558, 136378, 136443, 135149, +# 136133, 136068, 137076, 136500, 138413, 136125, 136253, 138345, 138132, +# 136620, 135675, 135674, 135600, 135666, 135665, 137653, +# 138309, 134650, 135102, 136760, 136761, 136250, 135665, 135600, 135592, +# 135593, 135666, 136178, 135601, 137661, 138245, 134650, +# 134591, 135102, 136761, 136250, 135666, 135602, 135666, 136308, 136639, +# 135622, 134794, 135241, 137415, 136263, 137403, 138099, 136262, 136267, +# 136259, 135747, 134784, 134839, 135738, 135674, 135602, 136178, 138498, +# 138636, 138628, 138500, 138363, 136250, 135610, 135602, 136186, 136379, +# 135040, 135535, 138285, 137271, 134676, 137950, 136312, 137095, 137413, +# 135332, 137346, 137409, 136972, 135041, 135030, 137095, 138622, 135124, +# 135842, 138349, 136264, 136779, 134838, 135224, 136589, 136085, 134721, +# 136391, 136833, 136250, 136177, 136754, 135725, 134904, 135355, 135428, +# 136838, 136880, 137951, 138351, 137985, 136892, 138346, 135870, 134614, +# 134623, 134624, 135210, 135283, 136938, 138398, 136250, 135601, +# 135592, 135658, 136178, 136753, 135238, 136456, 138559, 138629, 138496, +# 136250, 138056, 135883, 135365, 136125, 134573, 134581, 137275, 138346, +# 138055, 136908, 138248, 138120, 136073, 137026, 135205, 136702, 134714, +# 134584, 134591, 135103, 136191, 136761, 137519, 135339, 135868, 135411, +# 135775, 136865, 136817, 135601, 135592, 135666, 136690, +# 138634, 138628, 138355, 136754, 135993, 135542, 135085, 134571, 135091, +# 135667, 136379, 137611, 136262, 134892, 136028, 137083, 137827, 137108, +# 136178, 135174, 134792, 137416, 137473, 138565, 138309, 134592, 135614, +# 136185, 136186, 135086, 134571, 135164, 137588, 135803, 135739, 135875, +# 137285, 138151, 136596, 135068, 134772, 134916, 137797, 138119, 136073, +# 136186, 135166, 134661, 136248, 135737, 137095, 136325, 134700, 134659, +# 134917, 136430, 138518, 137265, 138248, 138314, 136188, 136598, 137133, +# 134552, 134546, 134547, 134555, 136420, 137580, 138108, 138436, 138354, +# 136762, 135610, ] + +# TEST +# numbers = [135081, 135080, 135081, 135666, 136883, 138099, 134597, 134768, 135402, +# 137951, 135928, 135038, 135109, 135436, 135484, 136375, 138057, 136523, +# 135883, 135042, 135038, 138177, 135061, 135061, 135061, 135061, 135061, +# 135061, 135061, 135061, 135061, 135061, 135061, 135061, 135061, 135061, 135061, +# 135061, 135061, 135061, 135061, 135061, 135061, 135061, 135061, 135061, 135061, +# 135061, 135061, 135061, 135061, 135061, 135061, 135061, 135061, 135061, 135061, +# 135061, 135061, 135061, 135061, 135061, 135061, 135061, 135061, 135061, 135061, +# 135061, 135061, 135061, 135061, 135061, 135061, 135061, 135061, 135061, 135061, +# 135061, 135061, 135061, 135061, 135061, 135061, 135061, 135061, 135061, 135061, +# 135061, 135061, 135061, 135061, 135061, 135061, 135061, 135061, 135061, 135061, +# 135061, 135061, 135061, 135061, 135061, +# 135592, 135658, 136178, 136753, 135238, 136456, 138559, 138629, 138496, +# 136250, 138056, 135883, 135365, 136125, 134573, 134581, 137275, 138346, +# 138055, 136908, 138248, 138120, 136073, 137026, 135205, 136702, 134714, +# 134584, 134591, 135103, 136191, 136761, 137519, 135339, 135868, 135411, +# 135775, 136865, 136817, 135601, 135592, 135080, 135081, 135666, 136690, +# 138634, 138628, 138355, 136754, 135993, 135542, 135085, 134571, 135091, +# 135667, 136379, 137611, 136262, 134892, 136028, 137083, 137827, 137108, +# 136178, 135174, 134792, 137416, 137473, 138565, 138309, 134592, 135614, +# 136185, 136186, 135086, 134571, 135164, 137588, 135803, 135739, 135875, +# 137285, 138151, 136596, 135068, 134772, 134916, 137797, 138119, 136073, +# 136186, 135166, 134661, 136248, 135737, 137095, 136325, 134700, 134659, +# 134917, 136430, 138518, 137265, 138248, 138314, 136188, 136598, 137133, +# 134552, 134546, 134547, 134555, 136420, 137580, 138108, 138436, 138354, +# 136762, 135610, 135081, 135080] + +print(numbers) +# 'NOT SUCH A SERIOUS PROBLEM IN NEVADA BUT IN OTHER PARTS OF THE COUNTRY WHERE WHERE GOVERNORS AND STATE LEGISLATORS HAVE OPTED OUT OF THE AFFORDABLE CARE ACT AND SO MANY OF THEIR OWN PEOPLE ARE LEFT BEHIND. +offset = 134541 + +# FUCKING 135080 == 539 in EMOVA +speech_tokens = [f"<|speech_{n - offset}|>" for n in numbers] + +# speech_tokens = [f"<|speech_{n}|>" for n in numbers] + +token_str = "".join(speech_tokens) + +# token_str = "<|speech_540|><|speech_539|><|speech_1053|><|speech_1068|><|speech_3121|><|speech_3256|><|speech_1479|><|speech_3055|><|speech_4088|><|speech_4027|><|speech_1653|><|speech_1049|><|speech_38|><|speech_2045|><|speech_3814|><|speech_2156|><|speech_3121|><|speech_2600|><|speech_2807|><|speech_3582|><|speech_3579|><|speech_3567|><|speech_1902|><|speech_1545|><|speech_2048|><|speech_1246|><|speech_3750|><|speech_2220|><|speech_2042|><|speech_1784|><|speech_95|><|speech_376|><|speech_2449|><|speech_3921|><|speech_4024|><|speech_3954|><|speech_569|><|speech_2593|><|speech_2079|><|speech_1583|><|speech_1911|><|speech_3071|><|speech_2430|><|speech_1197|><|speech_1702|><|speech_2559|><|speech_2811|><|speech_1651|><|speech_166|><|speech_1531|><|speech_1405|><|speech_3197|><|speech_3834|><|speech_2290|><|speech_1138|><|speech_3733|><|speech_3087|><|speech_2287|><|speech_3808|><|speech_568|><|speech_566|><|speech_1311|><|speech_3551|><|speech_3951|><|speech_3758|><|speech_2852|><|speech_3624|><|speech_2579|><|speech_2569|><|speech_3609|><|speech_371|><|speech_1774|><|speech_3694|><|speech_1576|><|speech_24|><|speech_26|><|speech_860|><|speech_3856|><|speech_1721|><|speech_2238|><|speech_4094|><|speech_4079|><|speech_3959|><|speech_3758|><|speech_1451|><|speech_497|><|speech_761|><|speech_231|><|speech_167|><|speech_2791|><|speech_3822|><|speech_2222|><|speech_1141|><|speech_548|><|speech_539|><|speech_540|><|speech_1052|><|speech_1125|><|speech_1701|><|speech_1444|><|speech_1001|><|speech_1056|><|speech_533|><|speech_614|><|speech_1126|><|speech_1254|><|speech_2215|><|speech_2097|><|speech_1145|><|speech_2721|><|speech_71|><|speech_591|><|speech_1004|><|speech_497|><|speech_1325|><|speech_1646|><|speech_1061|><|speech_540|><|speech_539|><|speech_1052|><|speech_1053|><|speech_2589|><|speech_3611|><|speech_2578|><|speech_1068|><|speech_1787|><|speech_2291|><|speech_2042|><|speech_544|><|speech_17|><|speech_995|><|speech_986|><|speech_3786|><|speech_3612|><|speech_552|><|speech_609|><|speech_3558|><|speech_2859|><|speech_2065|><|speech_1575|><|speech_1645|><|speech_824|><|speech_1144|><|speech_2800|><|speech_350|><|speech_2007|><|speech_3559|><|speech_3887|><|speech_3813|><|speech_2221|><|speech_1133|><|speech_549|><|speech_540|>" +print(token_str) + +with torch.no_grad(): + output_wav_path = f"/home/work/AIDAS/emova_outputs/gt_test.wav" + condition = 'gender-female_emotion-neutral_speed-normal_pitch-normal' + + vq_model.decode( + token_str, + condition=condition, + output_wav_file=output_wav_path + ) diff --git a/MMaDA/inference_i2i.py b/MMaDA/inference_i2i.py new file mode 100644 index 0000000000000000000000000000000000000000..1e244048c0de4a0708cd1c7ae758274fbd90e359 --- /dev/null +++ b/MMaDA/inference_i2i.py @@ -0,0 +1,167 @@ +# coding=utf-8 +# Copyright 2025 MMaDA Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import inspect +import sys + +os.environ["TOKENIZERS_PARALLELISM"] = "true" +from PIL import Image +from tqdm import tqdm +import numpy as np +import torch +import wandb +from models import MAGVITv2, get_mask_schedule, MMadaModelLM, MMadaConfig +from training.prompting_utils import UniversalPrompting +from training.utils import get_config, flatten_omega_conf, image_transform +from transformers import AutoTokenizer, AutoConfig, AutoModel +import torch.nn.functional as F + +import argparse +from datasets import load_dataset + +def resize_vocab(model, config): + print(f"Resizing token embeddings to {config.new_vocab_size}") + model.resize_token_embeddings(config.new_vocab_size) + + +def get_vq_model_class(model_type): + if model_type == "magvitv2": + return MAGVITv2 + else: + raise ValueError(f"model_type {model_type} not supported.") + +if __name__ == '__main__': + + config = get_config() + + + resume_wandb_run = config.wandb.resume + run_id = config.wandb.get("run_id", None) + if run_id is None: + resume_wandb_run = False + run_id = wandb.util.generate_id() + config.wandb.run_id = run_id + + wandb_config = {k: v for k, v in flatten_omega_conf(config, resolve=True)} + + wandb.init( + project="inference", + name=config.experiment.name + '_i2i', + config=wandb_config, + ) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + tokenizer = AutoTokenizer.from_pretrained(config.model.mmada.pretrained_model_path, padding_side="left") + + uni_prompting = UniversalPrompting(tokenizer, max_text_len=config.dataset.preprocessing.max_seq_length, special_tokens=("<|soi|>", "<|eoi|>", "<|sov|>", "<|eov|>", "<|t2i|>", "<|mmu|>", "<|t2v|>", "<|v2v|>", "<|lvg|>"),ignore_id=-100, cond_dropout_prob=config.training.cond_dropout_prob, use_reserved_token=True) + + vq_model = get_vq_model_class(config.model.vq_model.type) + vq_model = vq_model.from_pretrained(config.model.vq_model.vq_model_name).to(device) + vq_model.requires_grad_(False) + vq_model.eval() + + model = MMadaModelLM.from_pretrained(config.model.mmada.pretrained_model_path, trust_remote_code=True, torch_dtype=torch.bfloat16) + + + model.to(device) + + mask_token_id = model.config.mask_token_id + + # image_preprocessor = image_transform(resolution=config.dataset.params.resolution, is_train=False) + + # print(f"Loading dataset directly from: {config.hf_data_dir}") + + dataset = load_dataset("timbrooks/instructpix2pix-clip-filtered", split="train") + + # dataset = dataset.select(range(config.num_samples)) + + print(f"Processing {len(dataset)} samples.") + + # config.training.batch_size = config.batch_size + config.training.batch_size = 1 + config.training.guidance_scale = config.guidance_scale + config.training.generation_timesteps = config.generation_timesteps + + with open(config.dataset.params.validation_prompts_file, "r") as f: + validation_prompts = f.read().splitlines() + + for step, sample in enumerate(tqdm(dataset)): + + input_image = sample['original_image'].convert("RGB") + prompts = [sample['edit_prompt']] + + image = image_transform(input_image,resolution=config.dataset.params.resolution).unsqueeze(0).to(device) + + input_image_tokens = vq_model.get_code(image) + len(uni_prompting.text_tokenizer) + + output_image_placeholder = torch.ones((len(prompts), config.model.mmada.num_vq_tokens), + dtype=torch.long, device=device) * mask_token_id + + input_ids, attention_mask = uni_prompting( + (prompts, input_image_tokens, output_image_placeholder), 'i2i_gen' + ) + + if config.training.guidance_scale > 0: + uncond_input_ids, uncond_attention_mask = uni_prompting(([''] * len(prompts), input_image_tokens, + output_image_placeholder), 'i2i_gen') + else: + uncond_input_ids = None + uncond_attention_mask = None + + if config.get("mask_schedule", None) is not None: + schedule = config.mask_schedule.schedule + args = config.mask_schedule.get("params", {}) + mask_schedule = get_mask_schedule(schedule, **args) + else: + mask_schedule = get_mask_schedule(config.training.get("mask_schedule", "cosine")) + + with torch.no_grad(): + gen_token_ids = model.i2i_generate( + input_ids=input_ids, + uncond_input_ids=uncond_input_ids, + attention_mask=attention_mask, + uncond_attention_mask=uncond_attention_mask, + guidance_scale=config.training.guidance_scale, + temperature=config.training.get("generation_temperature", 1.0), + timesteps=config.training.generation_timesteps, + noise_schedule=mask_schedule, + noise_type=config.training.get("noise_type", "mask"), + seq_len=config.model.mmada.num_vq_tokens, + uni_prompting=uni_prompting, + config=config, + ) + + gen_token_ids = torch.clamp(gen_token_ids, max=config.model.mmada.codebook_size - 1, min=0) + + images = vq_model.decode_code(gen_token_ids) + + images = torch.clamp((images + 1.0) / 2.0, min=0.0, max=1.0) + images *= 255.0 + images = images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8) + pil_images = [Image.fromarray(image) for image in images] + + wandb_images = [] + pil_images = [Image.fromarray(image) for image in images] + + generated_image = pil_images[0] + original_image = input_image + + caption = f"Prompt: {prompts}" + wandb.log({ + "Image-to-Image Results": [ + wandb.Image(original_image, caption=f"Original (Step {step})"), + wandb.Image(generated_image, caption=f"Edited - {caption} (Step {step})") + ] + }, step=step) \ No newline at end of file diff --git a/MMaDA/inference_mmu.py b/MMaDA/inference_mmu.py new file mode 100644 index 0000000000000000000000000000000000000000..a53edabeaf503457ba2cc78fa9e41a39242c388a --- /dev/null +++ b/MMaDA/inference_mmu.py @@ -0,0 +1,116 @@ +# coding=utf-8 +# Copyright 2025 MMaDA Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +os.environ["TOKENIZERS_PARALLELISM"] = "true" +from PIL import Image +from tqdm import tqdm +import numpy as np +import torch +import wandb +from models import MAGVITv2, MMadaConfig, MMadaModelLM +from training.prompting_utils import UniversalPrompting +from training.utils import get_config, flatten_omega_conf, image_transform +from transformers import AutoTokenizer, AutoConfig + +def resize_vocab(model, config): + print(f"Resizing token embeddings to {config.new_vocab_size}") + model.resize_token_embeddings(config.new_vocab_size) + +def get_vq_model_class(model_type): + if model_type == "magvitv2": + return MAGVITv2 + else: + raise ValueError(f"model_type {model_type} not supported.") + +if __name__ == '__main__': + + config = get_config() + resume_wandb_run = config.wandb.resume + run_id = config.wandb.get("run_id", None) + if run_id is None: + resume_wandb_run = False + run_id = wandb.util.generate_id() + config.wandb.run_id = run_id + + wandb_config = {k: v for k, v in flatten_omega_conf(config, resolve=True)} + + wandb.init( + project="demo", + name=config.experiment.name + '_mmu', + config=wandb_config, + ) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + tokenizer = AutoTokenizer.from_pretrained(config.model.mmada.pretrained_model_path, padding_side="left") + + uni_prompting = UniversalPrompting(tokenizer, max_text_len=config.dataset.preprocessing.max_seq_length, + special_tokens=("<|soi|>", "<|eoi|>", "<|sov|>", "<|eov|>", "<|t2i|>", "<|mmu|>", "<|t2v|>", "<|v2v|>", "<|lvg|>"), + ignore_id=-100, cond_dropout_prob=config.training.cond_dropout_prob, use_reserved_token=True) + + vq_model = get_vq_model_class(config.model.vq_model.type) + vq_model = vq_model.from_pretrained(config.model.vq_model.vq_model_name).to(device) + vq_model.requires_grad_(False) + vq_model.eval() + + model = MMadaModelLM.from_pretrained(config.model.mmada.pretrained_model_path, trust_remote_code=True, torch_dtype=torch.bfloat16) + model.to(device) + + mask_token_id = model.config.mask_token_id + + temperature = 0.8 # 1.0 = no change, < 1.0 = less random, > 1.0 = more random, in predictions + top_k = 1 # retain only the top_k most likely tokens, clamp others to have 0 probability + file_list = os.listdir(config.mmu_image_root) + file_list = [f for f in file_list if f.lower().endswith(('.jpg', '.png', '.jpeg'))] + responses = ['' for i in range(len(file_list))] + images = [] + config.question = config.question.split(' *** ') + for i, file_name in enumerate(tqdm(file_list)): + image_path = os.path.join(config.mmu_image_root, file_name) + image_ori = Image.open(image_path).convert("RGB") + image = image_transform(image_ori, resolution=config.dataset.params.resolution).to(device) + image = image.unsqueeze(0) + images.append(image) + # Segmentation Fault + image_tokens = vq_model.get_code(image) + len(uni_prompting.text_tokenizer) + batch_size = 2 + + for question in config.question: + input_ids = uni_prompting.text_tokenizer(['<|start_header_id|>user<|end_header_id|>\n' + "Please describe this image in detail." +'<|start_header_id|>assistant<|end_header_id|>\n'])['input_ids'] + input_ids = torch.tensor(input_ids).to(device) + + input_ids = torch.cat([ + (torch.ones(input_ids.shape[0], 1) * uni_prompting.sptids_dict['<|mmu|>']).to(device), + (torch.ones(input_ids.shape[0], 1) * uni_prompting.sptids_dict['<|soi|>']).to(device), + image_tokens, + (torch.ones(input_ids.shape[0], 1) * uni_prompting.sptids_dict['<|eoi|>']).to(device), + (torch.ones(input_ids.shape[0], 1) * uni_prompting.sptids_dict['<|sot|>']).to(device), + input_ids + ], dim=1).long() + output_ids = model.mmu_generate(input_ids, max_new_tokens=1024, steps=512, block_length=1024) + text = uni_prompting.text_tokenizer.batch_decode(output_ids[:, input_ids.shape[1]:], skip_special_tokens=True) + print(text) + responses[i] += f'User: ' + question + f'\n Answer : ' + text[0] + '\n' + + images = torch.cat(images, dim=0) + images = torch.clamp((images + 1.0) / 2.0, min=0.0, max=1.0) + images *= 255.0 + images = images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8) + pil_images = [Image.fromarray(image) for image in images] + + wandb_images = [wandb.Image(image, caption=responses[i]) for i, image in enumerate(pil_images)] + wandb.log({"multimodal understanding": wandb_images}, step=0) + diff --git a/MMaDA/inference_mmu_ori.py b/MMaDA/inference_mmu_ori.py new file mode 100644 index 0000000000000000000000000000000000000000..466424001fd8a9d3016e7c52c511de721f0d304e --- /dev/null +++ b/MMaDA/inference_mmu_ori.py @@ -0,0 +1,115 @@ +# coding=utf-8 +# Copyright 2025 MMaDA Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +os.environ["TOKENIZERS_PARALLELISM"] = "true" +from PIL import Image +from tqdm import tqdm +import numpy as np +import torch +import wandb +from models import MAGVITv2, MMadaConfig, MMadaModelLM +from training.prompting_utils import UniversalPrompting +from training.utils import get_config, flatten_omega_conf, image_transform +from transformers import AutoTokenizer, AutoConfig + +def resize_vocab(model, config): + print(f"Resizing token embeddings to {config.new_vocab_size}") + model.resize_token_embeddings(config.new_vocab_size) + +def get_vq_model_class(model_type): + if model_type == "magvitv2": + return MAGVITv2 + else: + raise ValueError(f"model_type {model_type} not supported.") + +if __name__ == '__main__': + + config = get_config() + resume_wandb_run = config.wandb.resume + run_id = config.wandb.get("run_id", None) + if run_id is None: + resume_wandb_run = False + run_id = wandb.util.generate_id() + config.wandb.run_id = run_id + + wandb_config = {k: v for k, v in flatten_omega_conf(config, resolve=True)} + + wandb.init( + project="demo", + name=config.experiment.name + '_mmu', + config=wandb_config, + ) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + tokenizer = AutoTokenizer.from_pretrained(config.model.mmada.pretrained_model_path, padding_side="left") + + uni_prompting = UniversalPrompting(tokenizer, max_text_len=config.dataset.preprocessing.max_seq_length, + special_tokens=("<|soi|>", "<|eoi|>", "<|sov|>", "<|eov|>", "<|t2i|>", "<|mmu|>", "<|t2v|>", "<|v2v|>", "<|lvg|>"), + ignore_id=-100, cond_dropout_prob=config.training.cond_dropout_prob, use_reserved_token=True) + + vq_model = get_vq_model_class(config.model.vq_model.type) + vq_model = vq_model.from_pretrained(config.model.vq_model.vq_model_name).to(device) + vq_model.requires_grad_(False) + vq_model.eval() + + model = MMadaModelLM.from_pretrained(config.model.mmada.pretrained_model_path, trust_remote_code=True, torch_dtype=torch.bfloat16) + model.to(device) + + mask_token_id = model.config.mask_token_id + + temperature = 0.8 # 1.0 = no change, < 1.0 = less random, > 1.0 = more random, in predictions + top_k = 1 # retain only the top_k most likely tokens, clamp others to have 0 probability + file_list = os.listdir(config.mmu_image_root) + file_list = [f for f in file_list if f.lower().endswith(('.jpg', '.png', '.jpeg'))] + responses = ['' for i in range(len(file_list))] + images = [] + config.question = config.question.split(' *** ') + for i, file_name in enumerate(tqdm(file_list)): + image_path = os.path.join(config.mmu_image_root, file_name) + image_ori = Image.open(image_path).convert("RGB") + image = image_transform(image_ori, resolution=config.dataset.params.resolution).to(device) + image = image.unsqueeze(0) + images.append(image) + image_tokens = vq_model.get_code(image) + len(uni_prompting.text_tokenizer) + print(f"image tokens shape: {image_tokens.shape}") + batch_size = 1 + + for question in config.question: + input_ids = uni_prompting.text_tokenizer(['<|start_header_id|>user<|end_header_id|>\n' + "Please describe this image in detail." +'<|start_header_id|>assistant<|end_header_id|>\n'])['input_ids'] + input_ids = torch.tensor(input_ids).to(device) + + input_ids = torch.cat([ + (torch.ones(input_ids.shape[0], 1) * uni_prompting.sptids_dict['<|mmu|>']).to(device), + (torch.ones(input_ids.shape[0], 1) * uni_prompting.sptids_dict['<|soi|>']).to(device), + image_tokens, + (torch.ones(input_ids.shape[0], 1) * uni_prompting.sptids_dict['<|eoi|>']).to(device), + (torch.ones(input_ids.shape[0], 1) * uni_prompting.sptids_dict['<|sot|>']).to(device), + input_ids + ], dim=1).long() + print(f"input_ids shape: {input_ids.shape}") + output_ids = model.mmu_generate(input_ids, max_new_tokens=1024, steps=512, block_length=1024) + text = uni_prompting.text_tokenizer.batch_decode(output_ids[:, input_ids.shape[1]:], skip_special_tokens=True) + print(text) + responses[i] += f'User: ' + question + f'\n Answer : ' + text[0] + '\n' + + images = torch.cat(images, dim=0) + images = torch.clamp((images + 1.0) / 2.0, min=0.0, max=1.0) + images *= 255.0 + images = images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8) + pil_images = [Image.fromarray(image) for image in images] + + wandb_images = [wandb.Image(image, caption=responses[i]) for i, image in enumerate(pil_images)] + wandb.log({"multimodal understanding": wandb_images}, step=0) diff --git a/MMaDA/inference_s2t.py b/MMaDA/inference_s2t.py new file mode 100644 index 0000000000000000000000000000000000000000..8f6eb704bee18fcf12bd0a645653b5755b3f4db4 --- /dev/null +++ b/MMaDA/inference_s2t.py @@ -0,0 +1,162 @@ +# coding=utf-8 +# Copyright 2025 AIDAS Lab +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +os.environ["TOKENIZERS_PARALLELISM"] = "true" +from PIL import Image +from tqdm import tqdm +import numpy as np +import torch.nn.functional as F + +import torch +import wandb +from models import MMadaModelLM +from models.modeling_emova_speech_tokenizer import EMOVASpeechTokenizer +from training.prompting_utils import UniversalPrompting +from training.utils import get_config, flatten_omega_conf +from transformers import AutoTokenizer +import argparse + +# from models.modeling_speech_tokenizer import EMOVASpeechTokenizer + +def resize_vocab(model, config): + print(f"Resizing token embeddings to {config.model.mmada.new_vocab_size}") + model.resize_token_embeddings(config.model.mmada.new_vocab_size) + +def get_vq_model_class(model_type): + if model_type == "magvitv2": + return MAGVITv2 + elif model_type == "emova": + return EMOVASpeechTokenizer.from_pretrained( + "Emova-ollm/emova_speech_tokenizer_hf" + ) + else: + raise ValueError(f"model_type {model_type} not supported.") + +if __name__ == '__main__': + + config = get_config() + resume_wandb_run = config.wandb.resume + run_id = config.wandb.get("run_id", None) + if run_id is None: + resume_wandb_run = False + run_id = wandb.util.generate_id() + config.wandb.run_id = run_id + + wandb_config = {k: v for k, v in flatten_omega_conf(config, resolve=True)} + + wandb.init( + project="demo", + name=config.experiment.name + '_stt', + config=wandb_config, + ) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + text_tokenizer = AutoTokenizer.from_pretrained(config.model.mmada.pretrained_model_path, padding_side="left") + + uni_prompting = UniversalPrompting(text_tokenizer, max_text_len=config.dataset.preprocessing.max_seq_length, + special_tokens=("<|s2t|>", "<|soa|>", "<|eoa|>", "<|soi|>", "<|eoi|>", "<|sov|>", "<|eov|>", "<|t2i|>", "<|mmu|>", "<|t2v|>", "<|v2v|>", "<|lvg|>"), + ignore_id=-100, cond_dropout_prob=config.training.cond_dropout_prob, use_reserved_token=True) + + vq_model = get_vq_model_class(config.model.speech_model.type) + vq_model = vq_model.from_pretrained(config.model.speech_model.speech_model_name).to(device) + vq_model.requires_grad_(False) + vq_model.eval() + + quantizer = vq_model.encoder.quantizer + + if hasattr(quantizer, 'codebook_size'): + print("Codebook size:", quantizer.codebook_size) + + # 2) codebook ģž„ė² ė”© ė§¤ķŠøė¦­ģŠ¤ė”œė¶€ķ„° shape ģ¶”ģ¶œ + elif hasattr(quantizer, 'codebook'): + cb = quantizer.codebook # nn.Embedding ķ˜•ķƒœģ¼ ź°€ėŠ„ģ„± + print("Codebook size:", cb.weight.shape[0]) + + # 3) FSQģø 경우 levels 딜 ģ–‘ģžķ™” 단계 수 ķ™•ģø + elif hasattr(quantizer, 'levels'): + levels = quantizer.levels + print("Quantization levels per group:", levels) + print("Total scalar bins:", sum(levels)) + else: + raise RuntimeError("Quantizer에 codebook 정볓가 ģ—†ģŠµė‹ˆė‹¤.") + + sys.exit() + # model = MMadaModelLM.from_pretrained(config.model.mmada.pretrained_model_path, trust_remote_code=True, torch_dtype=torch.bfloat16) + + # c) Load main MMaDA model + # train_step = config.model.mmada.train_step + trained_checkpoint_path = f"/home/work/AIDAS/omada-training-stage1/checkpoint-10000/unwrapped_model/" + + print(f"Loading trained model from: {trained_checkpoint_path}") + model = MMadaModelLM.from_pretrained( + trained_checkpoint_path, + trust_remote_code=True, + torch_dtype=torch.bfloat16, + config='/home/work/AIDAS/ommda-training-s2t-mmada/config.json' + ) + print("āœ… Trained model loaded successfully!") + + # model = MMadaModelLM.from_pretrained(config.model.mmada.pretrained_model_path, trust_remote_code=True, torch_dtype=torch.bfloat16) + + # # d) Extend vocabulary for speech tokens + num_speech_tokens = 4096 + image_vocab_size = config.model.mmada.codebook_size # 8192 + text_vocab_size = len(uni_prompting.text_tokenizer) + + # resize_vocab(model, config) + + model.to(device) + mask_token_id = model.config.mask_token_id + + temperature = 0.8 # 1.0 = no change, < 1.0 = less random, > 1.0 = more random, in predictions + top_k = 1 # retain only the top_k most likely tokens, clamp others to have 0 probability + audio_file_list = os.listdir(config.audio_dir) + audio_file_list = [f for f in audio_file_list if f.lower().endswith(('.wav', '.flac', '.mp3'))] + results_table = wandb.Table(columns=["Audio File", "Response"]) + + for file_name in tqdm(audio_file_list, desc="Processing Audio"): + audio_path = os.path.join(config.audio_dir, file_name) + with torch.no_grad(): + + speech_token_ids = vq_model.encode(audio_path).to(device) + print(speech_token_ids) + speech_token_ids += text_vocab_size + image_vocab_size + + input_ids = text_tokenizer( + ['<|start_header_id|>user<|end_header_id|>\n' + config.question +'<|start_header_id|>assistant<|end_header_id|>\n'], + return_tensors="pt" + ).input_ids.to(device) + + input_ids = torch.cat([ + (torch.ones(input_ids.shape[0], 1) * uni_prompting.sptids_dict['<|s2t|>']).to(device), + (torch.ones(input_ids.shape[0], 1) * uni_prompting.sptids_dict['<|soa|>']).to(device), + speech_token_ids, + (torch.ones(input_ids.shape[0], 1) * uni_prompting.sptids_dict['<|eoa|>']).to(device), + # (torch.ones(input_ids.shape[0], 1) * uni_prompting.sptids_dict['<|sot|>']).to(device), + # input_ids + ], dim=1).long() + + output_ids = model.mmu_generate(input_ids, max_new_tokens=512, steps=512, block_length=512) + + # print(output_ids[:, input_ids.shape[1]:]) + + text = uni_prompting.text_tokenizer.batch_decode(output_ids[:, input_ids.shape[1]:], skip_special_tokens=True) + print(f"\nFile: {file_name}\nResponse: {text}") + results_table.add_data(file_name, text) + + wandb.log({"Speech-to-Text Response": results_table}) \ No newline at end of file diff --git a/MMaDA/inference_s2t_WER.py b/MMaDA/inference_s2t_WER.py new file mode 100644 index 0000000000000000000000000000000000000000..2ed2069a63231d878ac2b7d41cad7982e726411e --- /dev/null +++ b/MMaDA/inference_s2t_WER.py @@ -0,0 +1,397 @@ +# coding=utf-8 +# Copyright 2025 AIDAS Lab +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import random +import editdistance +from functools import partial +from normalizer import data_utils + +os.environ["TOKENIZERS_PARALLELISM"] = "true" + +from tqdm import tqdm +import torch +import torch.distributed as dist +from torch.utils.data import Dataset, DataLoader +from torch.utils.data.distributed import DistributedSampler +from torch.nn.parallel import DistributedDataParallel as DDP + +import wandb +from datasets import load_dataset +from models import MMadaModelLM +from models.modeling_emova_speech_tokenizer import EMOVASpeechTokenizer +from training.data import S2T_INSTRUCTION +from training.prompting_utils import UniversalPrompting +from training.utils import get_config, flatten_omega_conf +from transformers import AutoTokenizer + +import argparse +import logging +import re + +os.environ["TOKENIZERS_PARALLELISM"] = "true" + +from tqdm import tqdm +import torch +import torch.distributed as dist +from torch.utils.data import Dataset, DataLoader +from torch.utils.data.distributed import DistributedSampler +from torch.nn.parallel import DistributedDataParallel as DDP + +import wandb +from datasets import load_dataset +from models import MMadaModelLM +from models.modeling_emova_speech_tokenizer import EMOVASpeechTokenizer + +from training.data import S2T_INSTRUCTION +from training.prompting_utils import UniversalPrompting +from training.utils import get_config, flatten_omega_conf +from transformers import AutoTokenizer + +def setup_logger(rank): + + logger = logging.getLogger(__name__) + + # ķ•øė“¤ėŸ¬ 중복 추가 ė°©ģ§€ + if logger.hasHandlers(): + logger.handlers.clear() + + formatter = logging.Formatter(f'%(asctime)s - [RANK {rank}] - %(levelname)s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S') + + ch = logging.StreamHandler() + ch.setFormatter(formatter) + + logger.addHandler(ch) + + if rank == 0: + logger.setLevel(logging.INFO) + else: + logger.setLevel(logging.WARNING) + + return logger + +def calculate_WER(recognized_text_list, groundtruth_text_list): + """Calculates the Word Error Rate (WER) between predicted and ground truth texts.""" + word_num = 0.0 + scores = 0.0 + for recognized_text, groundtruth_text in zip(recognized_text_list, groundtruth_text_list): + + recognized_text = recognized_text.lower() + groundtruth_text = groundtruth_text.lower() + + recognized_text = re.sub(r"[^\w\s']", "", recognized_text) + groundtruth_text = re.sub(r"[^\w\s']", "", groundtruth_text) + + recognized_word_list = recognized_text.split() + groundtruth_word_list = groundtruth_text.split() + + current_word_num = len(groundtruth_word_list) + word_num += current_word_num + + current_score = editdistance.eval(recognized_word_list, groundtruth_word_list) + scores += current_score + + WER = scores / word_num if word_num > 0 else 0.0 + return WER, scores, word_num + +def get_vq_model_class(model_type): + """Returns the speech tokenizer model class based on the model type.""" + if model_type == "magvitv2": + raise NotImplementedError("MAGVITv2 is not implemented in this script.") + elif model_type == "emova": + return EMOVASpeechTokenizer + else: + raise ValueError(f"model_type {model_type} not supported.") + +def get_librispeech_dataset(logger): + """Loads the Librispeech ASR dataset (test-clean split) from Hugging Face.""" + logger.info("Loading EMOVA dataset (clean/test)...") + dataset = load_dataset("Emova-ollm/emova-asr-tts-eval/", "librispeech-asr-tts")['test'] + logger.info("Dataset loaded successfully.") + return dataset + +def form_ann_rst_list(ann, results, key): + + ann_dict = {} + for item in ann: + if key in item['id']: + ann_dict[item['id']] = item['conversations'][-1]['value'] + + rst_dict = {} + for item in results: + if key in item['id']: + rst_dict[item['id']] = item['text'] + + return ann_dict, rst_dict + +# --- DDP Setup and Cleanup Functions --- + +def setup_distributed(rank, world_size): + """Initializes the distributed process group.""" + dist.init_process_group("gloo", rank=rank, world_size=world_size) + +def cleanup_distributed(): + """Cleans up the distributed process group.""" + dist.destroy_process_group() + +# --- Custom Dataset and Collate Function --- + +class LibrispeechEvalDataset(Dataset): + def __init__(self, hf_dataset, root_path, vq_model, text_vocab_size, image_vocab_size): + self.hf_dataset = hf_dataset + self.root_path = root_path + self.vq_model = vq_model + self.text_vocab_size = text_vocab_size + self.image_vocab_size = image_vocab_size + + def __len__(self): + return len(self.hf_dataset) + + def __getitem__(self, idx): + example = self.hf_dataset[idx] + gt_text = example['text'] + sample_id = example['id'] + + speaker_id, chapter_id, _ = sample_id.split('-') + audio_path = os.path.join(self.root_path, speaker_id, chapter_id, f"{sample_id}.flac") + + if not os.path.exists(audio_path): + return None + + speech_token_ids = self.vq_model.encode(audio_path) + speech_token_ids += self.text_vocab_size + self.image_vocab_size + + return { + "speech_token_ids": speech_token_ids, + "gt_text": gt_text, + "sample_id": sample_id + } + +def evaluation_collate_fn(batch, text_tokenizer, uni_prompting, config): + batch = [b for b in batch if b is not None] + if not batch: + return None + + device = batch[0]["speech_token_ids"].device + + max_text_len = config.dataset.preprocessing.max_seq_length + max_audio_len = config.dataset.preprocessing.max_aud_length + 1 + audio_pad_id = 126093 + sptids_dict = uni_prompting.sptids_dict + + batched_input_ids = [] + gt_texts = [item["gt_text"] for item in batch] + sample_ids = [item["sample_id"] for item in batch] + + for item in batch: + current_audio_tokens = item["speech_token_ids"].to(device) + task_tensor = sptids_dict['<|s2t|>'].to(device).unsqueeze(0) + soa_tensor = sptids_dict['<|soa|>'].to(device).unsqueeze(0) + eoa_tensor = sptids_dict['<|eoa|>'].to(device).unsqueeze(0) + + effective_max_audio = max_audio_len - 3 + if current_audio_tokens.shape[1] > effective_max_audio: + current_audio_tokens = current_audio_tokens[:, :effective_max_audio] + + audio_block = torch.cat([task_tensor, soa_tensor, current_audio_tokens, eoa_tensor], dim=1) + + num_padding = max_audio_len - audio_block.shape[1] + if num_padding > 0: + padding_tensor = torch.full((1, num_padding), audio_pad_id, dtype=torch.long, device=device) + padded_audio_block = torch.cat([padding_tensor, audio_block], dim=1) + else: + padded_audio_block = audio_block + + chosen_prompt = random.choice(S2T_INSTRUCTION) + prompt_text = f'<|start_header_id|>user<|end_header_id|>\n{chosen_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n' + + prompt_encoding = text_tokenizer( + prompt_text, + max_length=max_text_len, + truncation=True, + return_tensors="pt" + ) + prompt_tensor = prompt_encoding.input_ids.to(device) + + final_sequence = torch.cat([padded_audio_block, prompt_tensor], dim=1) + batched_input_ids.append(final_sequence.squeeze(0)) + + pad_token_id = 126093 + + max_len = max(seq.size(0) for seq in batched_input_ids) + + + final_batch = torch.full((len(batched_input_ids), max_len), + pad_token_id, + dtype=torch.long, + device=device) + + + for i, seq in enumerate(batched_input_ids): + final_batch[i, -len(seq):] = seq + + return { + "input_ids": final_batch, + "gt_texts": gt_texts, + "sample_ids": sample_ids + } + +def main(): + """Main function to run the distributed evaluation.""" + rank = int(os.environ['RANK']) + world_size = int(os.environ['WORLD_SIZE']) + setup_distributed(rank, world_size) + device = torch.device(f"cuda:{rank}") + + logger = setup_logger(rank) + parser = argparse.ArgumentParser(description="Run DDP evaluation for MMadaModelLM.") + parser.add_argument('--train_step', type=int, required=True, help='The training step of the checkpoint to evaluate.') + parser.add_argument('--remasking', type=str, default='random', help='Remasking Strategy.') + parser.add_argument('--generation_step', type=int, default=512, help='The training step of the checkpoint to evaluate.') + parser.add_argument('--new_tok', type=int, default=256, help='The training step of the checkpoint to evaluate.') + args, unknown = parser.parse_known_args() + config = get_config() + + if rank == 0: + run_id = config.wandb.get("run_id", None) or wandb.util.generate_id() + config.wandb.run_id = run_id + wandb_config = {k: v for k, v in flatten_omega_conf(config, resolve=True)} + wandb.init( + project="librispeech_test-clean", + name=config.experiment.name + f'_STEP-{args.train_step}_Remasking-{args.remasking}_GS-{args.generation_step}_NT-{args.new_tok}', + config=wandb_config, + ) + + text_tokenizer = AutoTokenizer.from_pretrained(config.model.omada.pretrained_model_path, padding_side="left") + + uni_prompting = UniversalPrompting(text_tokenizer, max_text_len=config.dataset.preprocessing.max_seq_length, + special_tokens=("<|s2t|>", "<|soa|>", "<|eoa|>", "<|soi|>", "<|eoi|>", "<|sov|>", "<|eov|>", "<|t2i|>", "<|mmu|>", "<|t2v|>", "<|v2v|>", "<|lvg|>"), + ignore_id=-100, cond_dropout_prob=config.training.cond_dropout_prob, use_reserved_token=True) + + vq_model_class = get_vq_model_class(config.model.vq_model_audio.type) + vq_model = vq_model_class.from_pretrained(config.model.vq_model_audio.vq_model_name).to(device) + vq_model.requires_grad_(False) + vq_model.eval() + + train_step = args.train_step + trained_checkpoint_path = f"/home/work/AIDAS/ckpts/omada/omada-training-stage1/checkpoint-{train_step}/unwrapped_model/" + + if rank == 0: + logger.info(f"Loading trained model from: {trained_checkpoint_path}") + + model = MMadaModelLM.from_pretrained( + trained_checkpoint_path, + trust_remote_code=True, + torch_dtype=torch.bfloat16, + config="/home/work/AIDAS/ckpts/omada/omada-training-stage1/config.json" + ).to(device) + + model = DDP(model, device_ids=[rank]) + if rank == 0: + logger.info("āœ… Trained model loaded and wrapped with DDP successfully!") + + text_vocab_size = len(uni_prompting.text_tokenizer) + image_vocab_size = config.model.omada.codebook_size + + # --- Setup DataLoader --- + hf_dataset = get_librispeech_dataset(logger) + root_path = "/home/work/AIDAS/data/audio/LibriSpeech/test-clean" + + eval_dataset = LibrispeechEvalDataset(hf_dataset, root_path, vq_model, text_vocab_size, image_vocab_size) + sampler = DistributedSampler(eval_dataset, num_replicas=world_size, rank=rank, shuffle=False) + + collate_for_eval = partial( + evaluation_collate_fn, + text_tokenizer=text_tokenizer, + uni_prompting=uni_prompting, + config=config + ) + + dataloader = DataLoader( + eval_dataset, + batch_size=16, + sampler=sampler, + num_workers=0, + collate_fn=collate_for_eval, + pin_memory=True + ) + + # --- Evaluation Loop --- + local_results = [] + model.eval() + + progress_bar = tqdm(dataloader, desc="Evaluating on Librispeech", disable=(rank != 0)) + for batch_idx, batch in enumerate(progress_bar): + # if batch_idx > 1: + # break + if batch is None: + continue + + input_ids = batch["input_ids"].to(device) + gt_texts = batch["gt_texts"] + sample_ids = batch["sample_ids"] + + with torch.no_grad(): + output_ids = model.module.mmu_generate(input_ids, max_new_tokens=args.new_tok, steps=args.generation_step, block_length=args.new_tok, remasking=args.remasking) + decoded_texts = text_tokenizer.batch_decode(output_ids[:, input_ids.shape[1]:], skip_special_tokens=True) + + # print(decoded_texts) + for i in range(len(decoded_texts)): + local_results.append({ + "sample_id": sample_ids[i], + "gt_text": gt_texts[i], + "decoded_text": decoded_texts[i] + }) + + if rank == 0 and i == 0: + logger.info(f" ID: {sample_ids[i]}") + logger.info(f" GT: {data_utils.normalizer(gt_texts[i])}") + logger.info(f" PD: {data_utils.normalizer(decoded_texts[i])}") + + # --- Gather Results from All GPUs --- + all_results = [None] * world_size + dist.all_gather_object(all_results, local_results) + + # --- Final Processing and Logging (only on rank 0) --- + if rank == 0: + logger.info("Gathering and processing results from all GPUs...") + final_results = [item for sublist in all_results for item in sublist] + + groundtruth_text_list = [data_utils.normalizer(res["gt_text"]) for res in final_results] + recognized_text_list = [data_utils.normalizer(res["decoded_text"]) for res in final_results] + + results_table = wandb.Table(columns=["ID", "Ground Truth", "Response"]) + for res in final_results: + results_table.add_data(res["sample_id"], res["gt_text"], res["decoded_text"]) + + wandb.log({"Speech-to-Text Response Examples": results_table}) + + wer, errors, words = calculate_WER(groundtruth_text_list, recognized_text_list) + logger.info(f"Final WER (Librispeech test-clean): {wer:.4f} | Word Errors: {errors} | Total Words: {words}") + + wandb.log({ + "WER": wer, + "Total Word Errors": errors, + "Total Words": words + }) + + # --- Cleanup --- + if rank == 0: + wandb.finish() + cleanup_distributed() + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/MMaDA/inference_s2t_emova.py b/MMaDA/inference_s2t_emova.py new file mode 100644 index 0000000000000000000000000000000000000000..56b9f59a4dc1d3c47fcd8d0d4460d60f0f84aad4 --- /dev/null +++ b/MMaDA/inference_s2t_emova.py @@ -0,0 +1,358 @@ +# coding=utf-8 +# Copyright 2025 AIDAS Lab +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import random +import editdistance +from functools import partial +import re +from normalizer import data_utils + +os.environ["TOKENIZERS_PARALLELISM"] = "true" + +from tqdm import tqdm +import torch +import torch.distributed as dist +from torch.utils.data import Dataset, DataLoader +from torch.utils.data.distributed import DistributedSampler +from torch.nn.parallel import DistributedDataParallel as DDP + +import wandb +from datasets import load_dataset +from models import OMadaModelLM + +from training.data import S2T_INSTRUCTION +from training.prompting_utils import UniversalPrompting +from training.utils import get_config, flatten_omega_conf +from transformers import AutoTokenizer + +import argparse +import logging + +def setup_logger(rank): + logger = logging.getLogger(__name__) + if logger.hasHandlers(): + logger.handlers.clear() + formatter = logging.Formatter(f'%(asctime)s - [RANK {rank}] - %(levelname)s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S') + ch = logging.StreamHandler() + ch.setFormatter(formatter) + logger.addHandler(ch) + if rank == 0: + logger.setLevel(logging.INFO) + else: + logger.setLevel(logging.WARNING) + return logger + +def calculate_WER(recognized_text_list, groundtruth_text_list): + """Calculates the Word Error Rate (WER) between predicted and ground truth texts.""" + word_num = 0.0 + scores = 0.0 + for recognized_text, groundtruth_text in zip(recognized_text_list, groundtruth_text_list): + recognized_text = recognized_text.lower() + groundtruth_text = groundtruth_text.lower() + recognized_text = re.sub(r"[^\w\s']", "", recognized_text) + groundtruth_text = re.sub(r"[^\w\s']", "", groundtruth_text) + + recognized_word_list = recognized_text.split() + groundtruth_word_list = groundtruth_text.split() + + current_word_num = len(groundtruth_word_list) + word_num += current_word_num + + current_score = editdistance.eval(recognized_word_list, groundtruth_word_list) + scores += current_score + + WER = scores / word_num if word_num > 0 else 0.0 + return WER, scores, word_num + +# ### REMOVED ###: No longer need this function +# def get_vq_model_class(model_type): ... + +def get_emova_dataset(logger): + """Loads the EMOVA ASR/TTS evaluation dataset from Hugging Face.""" + logger.info("Loading EMOVA dataset (librispeech-asr-tts config)...") + dataset = load_dataset("Emova-ollm/emova-asr-tts-eval", "librispeech-asr-tts", split='test') + dataset = dataset.filter(lambda example: 'asr' in example['id']) + logger.info(f"Dataset loaded successfully. Found {len(dataset)} ASR examples.") + return dataset + +def setup_distributed(rank, world_size): + """Initializes the distributed process group.""" + dist.init_process_group("gloo", rank=rank, world_size=world_size) + +def cleanup_distributed(): + """Cleans up the distributed process group.""" + dist.destroy_process_group() + +# ### MODIFIED ###: Dataset class now parses speech tokens from string +class EMOVAAsrEvalDataset(Dataset): + def __init__(self, hf_dataset, text_vocab_size, image_vocab_size): + self.hf_dataset = hf_dataset + self.text_vocab_size = text_vocab_size + self.image_vocab_size = image_vocab_size + # Pre-compile the regex for efficiency + self.speech_token_pattern = re.compile(r'<\|speech_(\d+)\|>') + + def __len__(self): + return len(self.hf_dataset) + + def __getitem__(self, idx): + example = self.hf_dataset[idx] + + # Ground truth text is from the 'gpt' turn + gt_text = example['conversations'][-1]['value'] + sample_id = example['id'] + + # Audio tokens are in the 'human' turn as a string + audio_token_string = example['conversations'][0]['value'] + + # Parse the string to extract integer token IDs + speech_token_ids_str = self.speech_token_pattern.findall(audio_token_string) + # print(audio_token_string) + # print(speech_token_ids_str) + if not speech_token_ids_str: + return None # Handle cases with no speech tokens + + speech_token_ids = torch.tensor([int(s) for s in speech_token_ids_str], dtype=torch.long) + + # Shift audio token IDs to the correct range for the multimodal model's vocabulary + speech_token_ids += self.text_vocab_size + self.image_vocab_size + + return { + # Unsqueeze to add a batch dimension (consistent with original vq_model.encode output) + "speech_token_ids": speech_token_ids.unsqueeze(0), + "gt_text": gt_text, + "sample_id": sample_id + } + + +def evaluation_collate_fn(batch, text_tokenizer, uni_prompting, config): + batch = [b for b in batch if b is not None] + if not batch: + return None + + max_text_len = config.dataset.preprocessing.max_seq_length + max_audio_len = config.dataset.preprocessing.max_aud_length + 1 + audio_pad_id = 126093 + sptids_dict = uni_prompting.sptids_dict + + batched_input_ids = [] + gt_texts = [item["gt_text"] for item in batch] + sample_ids = [item["sample_id"] for item in batch] + + for item in batch: + current_audio_tokens = item["speech_token_ids"] + + task_tensor = sptids_dict['<|s2t|>'].to('cpu').unsqueeze(0) + soa_tensor = sptids_dict['<|soa|>'].to('cpu').unsqueeze(0) + eoa_tensor = sptids_dict['<|eoa|>'].to('cpu').unsqueeze(0) + + effective_max_audio = max_audio_len - 3 + if current_audio_tokens.shape[1] > effective_max_audio: + current_audio_tokens = current_audio_tokens[:, :effective_max_audio] + + audio_block = torch.cat([task_tensor, soa_tensor, current_audio_tokens, eoa_tensor], dim=1) + + num_padding = max_audio_len - audio_block.shape[1] + if num_padding > 0: + padding_tensor = torch.full((1, num_padding), audio_pad_id, dtype=torch.long) + padded_audio_block = torch.cat([padding_tensor, audio_block], dim=1) + else: + padded_audio_block = audio_block + + chosen_prompt = random.choice(S2T_INSTRUCTION) + prompt_text = f'<|start_header_id|>user<|end_header_id|>\n{chosen_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n' + + prompt_encoding = text_tokenizer( + prompt_text, + max_length=max_text_len, + truncation=True, + return_tensors="pt" + ) + prompt_tensor = prompt_encoding.input_ids + + final_sequence = torch.cat([padded_audio_block, prompt_tensor], dim=1) + batched_input_ids.append(final_sequence.squeeze(0)) + + pad_token_id = 126093 + max_len = max(seq.size(0) for seq in batched_input_ids) + final_batch = torch.full((len(batched_input_ids), max_len), + pad_token_id, + dtype=torch.long) + for i, seq in enumerate(batched_input_ids): + final_batch[i, -len(seq):] = seq + + return { + "input_ids": final_batch, + "gt_texts": gt_texts, + "sample_ids": sample_ids + } + +def main(): + """Main function to run the distributed evaluation.""" + rank = int(os.environ['RANK']) + world_size = int(os.environ['WORLD_SIZE']) + setup_distributed(rank, world_size) + device = torch.device(f"cuda:{rank}") + + logger = setup_logger(rank) + parser = argparse.ArgumentParser(description="Run DDP evaluation for MMadaModelLM on EMOVA dataset.") + parser.add_argument('--train_step', type=int, required=True, help='WIP') + parser.add_argument('--remasking', type=str, default='random', help='Remasking Strategy.') + parser.add_argument('--generation_step', type=int, default=512, help='WIP') + parser.add_argument('--new_tok', type=int, default=128, help='WIP') + parser.add_argument('--block_length', type=int, default=64, help='WIP') + # parser.add_argument('--ckpt_path', type=str, required=True, help='WIP') + args, unknown = parser.parse_known_args() + config = get_config() + + if rank == 0: + run_id = config.wandb.get("run_id", None) or wandb.util.generate_id() + config.wandb.run_id = run_id + wandb_config = {k: v for k, v in flatten_omega_conf(config, resolve=True)} + wandb.init( + project="merging_grid", + name=f'{config.experiment.name}-STEP-{args.train_step}-Remasking-{args.remasking}-GS-{args.generation_step}-NT-{args.new_tok}', + config=wandb_config, + ) + + text_tokenizer = AutoTokenizer.from_pretrained(config.model.omada.pretrained_model_path, padding_side="left") + + uni_prompting = UniversalPrompting(text_tokenizer, max_text_len=config.dataset.preprocessing.max_seq_length, + special_tokens=("<|s2t|>", "<|soa|>", "<|eoa|>", "<|soi|>", "<|eoi|>", "<|sov|>", "<|eov|>", "<|t2i|>", "<|mmu|>", "<|t2v|>", "<|v2v|>", "<|lvg|>"), + ignore_id=-100, cond_dropout_prob=config.training.cond_dropout_prob, use_reserved_token=True) + + # ### REMOVED ###: VQ Model is not needed anymore + # vq_model_class = get_vq_model_class(config.model.vq_model_audio.type) + # vq_model = vq_model_class.from_pretrained(config.model.vq_model_audio.vq_model_name).to(device) + # vq_model.requires_grad_(False) + # vq_model.eval() + + train_step = args.train_step + # trained_checkpoint_path = f"/home/work/AIDAS/ckpts/omada/omada-training-stage1/checkpoint-{train_step}/unwrapped_model/" + trained_checkpoint_path = f"/home/work/AIDAS/ckpts/omada/omada-training-stage1_2nd/checkpoint-50000/unwrapped_model" + # trained_checkpoint_path = args.ckpt_path + + if rank == 0: + logger.info(f"Loading trained model from: {trained_checkpoint_path}") + + model = OMadaModelLM.from_pretrained( + trained_checkpoint_path, + trust_remote_code=True, + torch_dtype=torch.bfloat16, + config="/home/work/AIDAS/ckpts/omada/omada-training-stage1/config.json" + ).to(device) + + print("BEFORE DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + print("AFTER DDP") + if rank == 0: + logger.info("āœ… Trained model loaded and wrapped with DDP successfully!") + + text_vocab_size = len(uni_prompting.text_tokenizer) + image_vocab_size = config.model.omada.codebook_size + + # --- Setup DataLoader --- + hf_dataset = get_emova_dataset(logger) + + # ### MODIFIED ###: Pass only necessary arguments to the dataset class + eval_dataset = EMOVAAsrEvalDataset(hf_dataset, text_vocab_size, image_vocab_size) + sampler = DistributedSampler(eval_dataset, num_replicas=world_size, rank=rank, shuffle=False) + + collate_for_eval = partial( + evaluation_collate_fn, + text_tokenizer=text_tokenizer, + uni_prompting=uni_prompting, + config=config + ) + + dataloader = DataLoader( + eval_dataset, + batch_size=16, + sampler=sampler, + num_workers=0, + collate_fn=collate_for_eval, + pin_memory=True + ) + + # --- Evaluation Loop --- + local_results = [] + model.eval() + + progress_bar = tqdm(dataloader, desc="Evaluating on EMOVA ASR", disable=(rank != 0)) + for batch in progress_bar: + if batch is None: + continue + + input_ids = batch["input_ids"].to(device) + gt_texts = batch["gt_texts"] + sample_ids = batch["sample_ids"] + + # print(input_ids) + # print(gt_texts) + # print(sample_ids) + + with torch.no_grad(): + output_ids = model.module.mmu_generate(input_ids, max_new_tokens=args.new_tok, steps=args.generation_step, block_length=args.block_length, remasking=args.remasking) + decoded_texts = text_tokenizer.batch_decode(output_ids[:, input_ids.shape[1]:], skip_special_tokens=True) + + for i in range(len(decoded_texts)): + local_results.append({ + "sample_id": sample_ids[i], + "gt_text": gt_texts[i], + "decoded_text": decoded_texts[i] + }) + + if rank == 0 and i == 0 and len(local_results) % 10 == 1: + logger.info(f"\n--- Example ---") + logger.info(f" ID: {sample_ids[i]}") + logger.info(f" GT: {gt_texts[i]}") + logger.info(f" PD: {decoded_texts[i]}") + logger.info(f"-----------------\n") + + # --- Gather Results from All GPUs --- + all_results = [None] * world_size + dist.all_gather_object(all_results, local_results) + + # --- Final Processing and Logging (only on rank 0) --- + if rank == 0: + logger.info("Gathering and processing results from all GPUs...") + final_results = [item for sublist in all_results for item in sublist] + + groundtruth_text_list = [data_utils.normalizer(res["gt_text"]) for res in final_results] + recognized_text_list = [data_utils.normalizer(res["decoded_text"]) for res in final_results] + + results_table = wandb.Table(columns=["ID", "Ground Truth", "Response"]) + for res in final_results: + results_table.add_data(res["sample_id"], res["gt_text"], res["decoded_text"]) + + wandb.log({"Speech-to-Text Response Examples": results_table}) + + wer, errors, words = calculate_WER(recognized_text_list, groundtruth_text_list) + logger.info(f"Final WER (EMOVA test): {wer:.4f} | Word Errors: {int(errors)} | Total Words: {int(words)}") + + wandb.log({ + "WER": wer, + "Total Word Errors": errors, + "Total Words": words + }) + + # --- Cleanup --- + if rank == 0: + wandb.finish() + cleanup_distributed() + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/MMaDA/inference_t2i.py b/MMaDA/inference_t2i.py new file mode 100644 index 0000000000000000000000000000000000000000..76d25d27a690174d8e7a3f4dc05a0984b09769d0 --- /dev/null +++ b/MMaDA/inference_t2i.py @@ -0,0 +1,132 @@ +# coding=utf-8 +# Copyright 2025 MMaDA Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import inspect +import sys + +os.environ["TOKENIZERS_PARALLELISM"] = "true" +from PIL import Image +from tqdm import tqdm +import numpy as np +import torch +import wandb +from models import MAGVITv2, get_mask_schedule, MMadaModelLM, MMadaConfig +from training.prompting_utils import UniversalPrompting +from training.utils import get_config, flatten_omega_conf, image_transform +from transformers import AutoTokenizer, AutoConfig, AutoModel +import torch.nn.functional as F + +def resize_vocab(model, config): + print(f"Resizing token embeddings to {config.new_vocab_size}") + model.resize_token_embeddings(config.new_vocab_size) + + +def get_vq_model_class(model_type): + if model_type == "magvitv2": + return MAGVITv2 + else: + raise ValueError(f"model_type {model_type} not supported.") + +if __name__ == '__main__': + + config = get_config() + + + resume_wandb_run = config.wandb.resume + run_id = config.wandb.get("run_id", None) + if run_id is None: + resume_wandb_run = False + run_id = wandb.util.generate_id() + config.wandb.run_id = run_id + + wandb_config = {k: v for k, v in flatten_omega_conf(config, resolve=True)} + + wandb.init( + project="demo", + name=config.experiment.name + '_t2i', + config=wandb_config, + ) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + tokenizer = AutoTokenizer.from_pretrained(config.model.mmada.pretrained_model_path, padding_side="left") + + uni_prompting = UniversalPrompting(tokenizer, max_text_len=config.dataset.preprocessing.max_seq_length, special_tokens=("<|soi|>", "<|eoi|>", "<|sov|>", "<|eov|>", "<|t2i|>", "<|mmu|>", "<|t2v|>", "<|v2v|>", "<|lvg|>"),ignore_id=-100, cond_dropout_prob=config.training.cond_dropout_prob, use_reserved_token=True) + + vq_model = get_vq_model_class(config.model.vq_model.type) + vq_model = vq_model.from_pretrained(config.model.vq_model.vq_model_name).to(device) + vq_model.requires_grad_(False) + vq_model.eval() + + print(vq_model) + sys.exit() + + model = MMadaModelLM.from_pretrained(config.model.mmada.pretrained_model_path, trust_remote_code=True, torch_dtype=torch.bfloat16) + + + model.to(device) + + mask_token_id = model.config.mask_token_id + if config.get("validation_prompts_file", None) is not None: + config.dataset.params.validation_prompts_file = config.validation_prompts_file + config.training.batch_size = config.batch_size + config.training.guidance_scale = config.guidance_scale + config.training.generation_timesteps = config.generation_timesteps + + with open(config.dataset.params.validation_prompts_file, "r") as f: + validation_prompts = f.read().splitlines() + + for step in tqdm(range(0, len(validation_prompts), config.training.batch_size)): + prompts = validation_prompts[step:step + config.training.batch_size] + + image_tokens = torch.ones((len(prompts), config.model.mmada.num_vq_tokens), + dtype=torch.long, device=device) * mask_token_id + input_ids, attention_mask = uni_prompting((prompts, image_tokens), 't2i_gen') + if config.training.guidance_scale > 0: + uncond_input_ids, uncond_attention_mask = uni_prompting(([''] * len(prompts), image_tokens), 't2i_gen') + else: + uncond_input_ids = None + uncond_attention_mask = None + + if config.get("mask_schedule", None) is not None: + schedule = config.mask_schedule.schedule + args = config.mask_schedule.get("params", {}) + mask_schedule = get_mask_schedule(schedule, **args) + else: + mask_schedule = get_mask_schedule(config.training.get("mask_schedule", "cosine")) + with torch.no_grad(): + gen_token_ids = model.t2i_generate( + input_ids=input_ids, + uncond_input_ids=uncond_input_ids, + attention_mask=attention_mask, + uncond_attention_mask=uncond_attention_mask, + guidance_scale=config.training.guidance_scale, + temperature=config.training.get("generation_temperature", 1.0), + timesteps=config.training.generation_timesteps, + noise_schedule=mask_schedule, + noise_type=config.training.get("noise_type", "mask"), + seq_len=config.model.mmada.num_vq_tokens, + uni_prompting=uni_prompting, + config=config, + ) + + gen_token_ids = torch.clamp(gen_token_ids, max=config.model.mmada.codebook_size - 1, min=0) + images = vq_model.decode_code(gen_token_ids) + images = torch.clamp((images + 1.0) / 2.0, min=0.0, max=1.0) + images *= 255.0 + images = images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8) + pil_images = [Image.fromarray(image) for image in images] + + wandb_images = [wandb.Image(image, caption=prompts[i]) for i, image in enumerate(pil_images)] + wandb.log({"generated_images": wandb_images}, step=step) diff --git a/MMaDA/inference_t2s.py b/MMaDA/inference_t2s.py new file mode 100644 index 0000000000000000000000000000000000000000..473b9c89afdcb396fac95f5a13bc355dd2655f29 --- /dev/null +++ b/MMaDA/inference_t2s.py @@ -0,0 +1,165 @@ +# coding=utf-8 +# Copyright 2025 AIDAS Lab +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +os.environ["TOKENIZERS_PARALLELISM"] = "true" +from PIL import Image +from tqdm import tqdm +import numpy as np +import torch +import wandb +from models import MMadaModelLM +from models import MAGVITv2, get_mask_schedule, MMadaModelLM, MMadaConfig +from models.modeling_emova_speech_tokenizer import EMOVASpeechTokenizer +from training.prompting_utils import UniversalPrompting +from training.utils import get_config, flatten_omega_conf +from transformers import AutoTokenizer +import argparse + +def resize_vocab(model, config): + print(f"Resizing token embeddings to {config.model.mmada.new_vocab_size}") + model.resize_token_embeddings(config.model.mmada.new_vocab_size) + + +def get_vq_model_class(model_type): + if model_type == "magvitv2": + return MAGVITv2 + elif model_type == "emova": + return EMOVASpeechTokenizer.from_pretrained( + "Emova-ollm/emova_speech_tokenizer_hf" + ) + else: + raise ValueError(f"model_type {model_type} not supported.") + +if __name__ == '__main__': + config = get_config() + resume_wandb_run = config.wandb.resume + run_id = config.wandb.get("run_id", None) + if run_id is None: + resume_wandb_run = False + run_id = wandb.util.generate_id() + config.wandb.run_id = run_id + + wandb_config = {k: v for k, v in flatten_omega_conf(config, resolve=True)} + + wandb.init( + project="demo", + name=config.experiment.name + '_t2s', + config=wandb_config, + ) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + text_tokenizer = AutoTokenizer.from_pretrained(config.model.mmada.pretrained_model_path, padding_side="left") + + uni_prompting = UniversalPrompting(text_tokenizer, max_text_len=config.dataset.preprocessing.max_seq_length, + special_tokens=("<|s2t|>", "<|soa|>", "<|eoa|>", "<|soi|>", "<|eoi|>", "<|sov|>", "<|eov|>", "<|t2i|>", "<|mmu|>", "<|t2v|>", "<|v2v|>", "<|lvg|>", "<|t2s|>"), + ignore_id=-100, cond_dropout_prob=config.training.cond_dropout_prob, use_reserved_token=True) + + # b) Load speech tokenizer/detokenizer + vq_model = get_vq_model_class(config.model.speech_model.type) + vq_model = vq_model.from_pretrained(config.model.speech_model.speech_model_name).to(device) + vq_model.requires_grad_(False) + vq_model.eval() + + # c) Load main MMaDA model + train_step = config.model.mmada.train_step + trained_checkpoint_path = f"/home/work/AIDAS/ckpts/omada/omada-training-stage1/checkpoint-{train_step}/unwrapped_model" + # trained_checkpoint_path = "/home/work/AIDAS/omada-training-stage1/checkpoint-10000/unwrapped_model" + + print(f"Loading trained model from: {trained_checkpoint_path}") + model = MMadaModelLM.from_pretrained( + trained_checkpoint_path, + trust_remote_code=True, + torch_dtype=torch.bfloat16, + config='/home/work/AIDAS/ommda-training-s2t-mmada/config.json' # Should be changed to t2s after the train ends + ) + print("āœ… Trained model loaded successfully!") + + # model = MMadaModelLM.from_pretrained(config.model.mmada.pretrained_model_path, trust_remote_code=True, torch_dtype=torch.bfloat16) + + # # d) Extend vocabulary for speech tokens + num_speech_tokens = 4096 + image_vocab_size = config.model.mmada.codebook_size # 8192 + # text_vocab_size = len(uni_prompting.text_tokenizer) + + # resize_vocab(model, config) + model.to(device).eval() + + mask_token_id = model.config.mask_token_id + if config.get("validation_prompts_file", None) is not None: + config.dataset.params.validation_prompts_file = config.validation_prompts_file + config.training.batch_size = config.batch_size + config.training.guidance_scale = config.guidance_scale + config.training.generation_timesteps = config.generation_timesteps + + with open(config.dataset.params.validation_prompts_file, "r") as f: + validation_prompts = f.read().splitlines() + + for step in tqdm(range(0, len(validation_prompts), config.training.batch_size)): + prompts = validation_prompts[step:step + config.training.batch_size] + + audio_tokens = torch.ones((len(prompts), config.model.mmada.num_speech_vq_tokens), + dtype=torch.long, device=device) * mask_token_id + input_ids, attention_mask = uni_prompting((prompts, audio_tokens), 't2s_gen') + if config.training.guidance_scale > 0: + uncond_input_ids, uncond_attention_mask = uni_prompting(([''] * len(prompts), audio_tokens), 't2s_gen') + else: + uncond_input_ids = None + uncond_attention_mask = None + + if config.get("mask_schedule", None) is not None: + schedule = config.mask_schedule.schedule + args = config.mask_schedule.get("params", {}) + mask_schedule = get_mask_schedule(schedule, **args) + else: + mask_schedule = get_mask_schedule(config.training.get("mask_schedule", "cosine")) + with torch.no_grad(): + # TODO: Implement t2s_generate + gen_token_ids = model.t2s_generate( + input_ids=input_ids, + uncond_input_ids=uncond_input_ids, + attention_mask=attention_mask, + uncond_attention_mask=uncond_attention_mask, + guidance_scale=config.training.guidance_scale, + temperature=config.training.get("generation_temperature", 1.0), + timesteps=config.training.generation_timesteps, + noise_schedule=mask_schedule, + noise_type=config.training.get("noise_type", "mask"), + seq_len=config.model.mmada.num_speech_vq_tokens, + uni_prompting=uni_prompting, + config=config, + ) + + gen_token_ids = torch.clamp(gen_token_ids, max=config.model.mmada.speech_codebook_size - 1, min=0) + id_list = gen_token_ids[0].cpu().tolist() + print(len(id_list)) + speech_unit_str = " ".join(map(str, id_list)) + speech_unit_for_decode = "".join([f"<|speech_{unit}|>" for unit in speech_unit_str.split(" ")]) + + + output_wav_path = f"/home/work/AIDAS/output/omada_tmp/generated_audio_step_{train_step}_{step}_item.wav" + # Using a default condition, this can be made more dynamic if needed + condition = 'gender-female_emotion-neutral_speed-normal_pitch-normal' + + vq_model.decode( + speech_unit_for_decode, + condition=condition, + output_wav_file=output_wav_path + ) + + wandb.log({ + f"Generated Audio/{step*config.training.batch_size}": wandb.Audio(output_wav_path, caption=prompts) + }, step=step) \ No newline at end of file diff --git a/MMaDA/inference_t2s_emova.py b/MMaDA/inference_t2s_emova.py new file mode 100644 index 0000000000000000000000000000000000000000..8f22b33e36a4a7f3d6373c27522e05d012729d3c --- /dev/null +++ b/MMaDA/inference_t2s_emova.py @@ -0,0 +1,275 @@ +# coding=utf-8 +# Copyright 2025 AIDAS Lab + +import os +import random +import editdistance +from functools import partial +import re +import soundfile as sf +import numpy as np + +os.environ["TOKENIZERS_PARALLELISM"] = "true" + +from tqdm import tqdm +import torch +import torch.distributed as dist +from torch.utils.data import Dataset, DataLoader +from torch.utils.data.distributed import DistributedSampler +from torch.nn.parallel import DistributedDataParallel as DDP + +import wandb +from datasets import load_dataset +from models import OMadaModelLM +from training.data import T2S_INSTRUCTION # T2S_INSTRUCTION import +from models.modeling_emova_speech_tokenizer import EMOVASpeechTokenizer +from training.prompting_utils import UniversalPrompting +from training.utils import get_config, flatten_omega_conf +from models import get_mask_schedule +from transformers import AutoTokenizer, pipeline + +import argparse +import logging + +# --- (setup_logger, calculate_WER, get_emova_dataset_tts, EMOVATtsEvalDataset, setup_distributed, cleanup_distributed ķ•Øģˆ˜ėŠ” ģ“ģ „ź³¼ ė™ģ¼) --- +def setup_logger(rank): + logger = logging.getLogger(__name__) + if logger.hasHandlers(): + logger.handlers.clear() + formatter = logging.Formatter(f'%(asctime)s - [RANK {rank}] - %(levelname)s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S') + ch = logging.StreamHandler() + ch.setFormatter(formatter) + logger.addHandler(ch) + if rank == 0: + logger.setLevel(logging.INFO) + else: + logger.setLevel(logging.WARNING) + return logger + +def calculate_WER(recognized_text_list, groundtruth_text_list): + word_num = 0.0 + scores = 0.0 + for recognized_text, groundtruth_text in zip(recognized_text_list, groundtruth_text_list): + recognized_text = recognized_text.lower(); groundtruth_text = groundtruth_text.lower() + recognized_text = re.sub(r"[^\w\s']", "", recognized_text); groundtruth_text = re.sub(r"[^\w\s']", "", groundtruth_text) + recognized_word_list = recognized_text.split(); groundtruth_word_list = groundtruth_text.split() + current_word_num = len(groundtruth_word_list); word_num += current_word_num + current_score = editdistance.eval(recognized_word_list, groundtruth_word_list); scores += current_score + WER = scores / word_num if word_num > 0 else 0.0 + return WER, scores, word_num + +def get_emova_dataset_tts(logger): + logger.info("Loading EMOVA dataset (librispeech-asr-tts config) for TTS...") + dataset = load_dataset("Emova-ollm/emova-asr-tts-eval", "librispeech-asr-tts", split='test') + original_count = len(dataset) + dataset = dataset.filter( + lambda example: 'tts' in example['id'] and '", "<|soa|>", "<|eoa|>", "<|soi|>", "<|eoi|>", "<|sov|>", "<|eov|>", "<|t2i|>", "<|mmu|>", "<|t2v|>", "<|v2v|>", "<|lvg|>"), + ignore_id=-100, cond_dropout_prob=config.training.cond_dropout_prob, use_reserved_token=True) + + if rank == 0: + logger.info("Loading Whisper model for evaluation...") + whisper_pipe = pipeline("automatic-speech-recognition", model="openai/whisper-large-v3", device=device) + logger.info("Whisper model loaded.") + + logger.info("Loading EMOVA VQ model (vocoder)...") + vq_model = EMOVASpeechTokenizer.from_pretrained(config.model.vq_model_audio.vq_model_name).to(device) + vq_model.eval() + logger.info("EMOVA VQ model loaded.") + + # trained_checkpoint_path = f"/home/work/AIDAS/ckpts/omada/omada-training-stage1/checkpoint-{args.train_step}/unwrapped_model/" + # trained_checkpoint_path = "/home/work/AIDAS/ckpts/omada/omada-training-stage1/checkpoint-145000/unwrapped_model" + trained_checkpoint_path = f"/home/work/AIDAS/ckpts/omada/omada-training-stage1_2nd/checkpoint-45000/unwrapped_model" + + if rank == 0: + logger.info(f"Loading trained MMada model from: {trained_checkpoint_path}") + + model = OMadaModelLM.from_pretrained( + trained_checkpoint_path, + trust_remote_code=True, + torch_dtype=torch.bfloat16, + config="/home/work/AIDAS/ckpts/omada/omada-training-stage1/config.json" + ).to(device) + + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + logger.info("āœ… Trained MMada model loaded and wrapped with DDP successfully!") + + hf_dataset = get_emova_dataset_tts(logger) + eval_dataset = EMOVATtsEvalDataset(hf_dataset) + sampler = DistributedSampler(eval_dataset, num_replicas=world_size, rank=rank, shuffle=False) + # ### CORRECTED ###: Remove custom collate_fn, default is sufficient + dataloader = DataLoader(eval_dataset, batch_size=16, sampler=sampler, num_workers=0) + + local_results = [] + model.eval() + + mask_token_id = 126336 + + if config.get("mask_schedule", None) is not None: + schedule = config.mask_schedule.schedule + schedule_args = config.mask_schedule.get("params", {}) + mask_schedule = get_mask_schedule(schedule, **schedule_args) + else: + mask_schedule = get_mask_schedule(config.training.get("mask_schedule", "cosine")) + + progress_bar = tqdm(dataloader, desc="Evaluating TTS on EMOVA", disable=(rank != 0)) + for batch_idx, batch in enumerate(progress_bar): + if batch is None: + continue + + gt_texts = batch["gt_text"] + sample_ids = batch["sample_id"] + + # ### CORRECTED & SIMPLIFIED PROMPT PREPARATION ### + prompts = [] + for text in gt_texts: + text = text.rsplit("\n", 1)[-1].strip() + chosen_prompt = random.choice(T2S_INSTRUCTION) + full_instruction = f"{text}\n{chosen_prompt}" # Combine instruction and text + prompts.append(full_instruction) + + print(prompts[0]) + batch_size = len(prompts) + + # Using speech_token_length from args + print(args.speech_token_length -1) + audio_tokens = torch.ones((batch_size, args.speech_token_length -1 ), dtype=torch.long, device=device) * mask_token_id # 99 tokens + input_ids, attention_mask = uni_prompting((prompts, audio_tokens), 't2s_gen') + + if args.guidance_scale > 0: + uncond_input_ids, uncond_attention_mask = uni_prompting(([''] * batch_size, audio_tokens), 't2s_gen') + else: + uncond_input_ids, uncond_attention_mask = None, None + + with torch.no_grad(): + # ### CORRECTED t2s_generate call with proper arguments ### + output_ids = model.module.t2s_generate( + input_ids=input_ids, + uncond_input_ids=uncond_input_ids, + attention_mask=attention_mask, + uncond_attention_mask=uncond_attention_mask, + guidance_scale=args.guidance_scale, + temperature=1.0, # Hardcoded temperature as example + timesteps=args.timesteps, + noise_schedule=mask_schedule, + noise_type="mask", + seq_len=args.speech_token_length, + uni_prompting=uni_prompting, + config=config, + ) + + if rank == 0: + for i in range(batch_size): + gt = gt_texts[i].rsplit("\n", 1)[-1].strip() + + gen_token_ids = output_ids[i] + + # print(gt) + # print(gen_token_ids) + + clamped_ids = torch.clamp(gen_token_ids, max=4096 - 1, min=0) + id_list = clamped_ids.cpu().tolist() + + speech_unit_str = " ".join(map(str, id_list)) + speech_unit_for_decode = "".join([f"<|speech_{unit}|>" for unit in speech_unit_str.split(" ")]) + + output_wav_path = f"/home/work/AIDAS/t2s_logs/tts_output_{sample_ids[i]}.wav" + condition = 'gender-female_emotion-neutral_speed-normal_pitch-normal' + + vq_model.decode( + speech_unit_for_decode, + condition=condition, + output_wav_file=output_wav_path + ) + + whisper_result = whisper_pipe(output_wav_path, generate_kwargs={"language": "english"}) + whisper_text = whisper_result.get("text", "") + + local_results.append({ + "sample_id": sample_ids[i], "gt_text": gt, "whisper_text": whisper_text + }) + + if i == 0: + logger.info(f"\n--- TTS Example (Batch {batch_idx}) ---") + logger.info(f" ID: {sample_ids[i]}; GT: {gt}; Whisper: {whisper_text}") + logger.info(f" (Audio saved to {output_wav_path})") + + wandb.log({ + f"Generated Audio/{sample_ids[i]}": wandb.Audio(output_wav_path, caption=f"ID: {sample_ids[i]}\nGT: {gt}\nWhisper: {whisper_text}") + }) + + all_results = [None] * world_size + dist.all_gather_object(all_results, local_results) + + if rank == 0: + final_results = [item for sublist in all_results for item in sublist if item is not None] + if final_results: + groundtruth_text_list = [res["gt_text"] for res in final_results] + recognized_text_list = [res["whisper_text"] for res in final_results] + + results_table = wandb.Table(columns=["ID", "Ground Truth Text", "Whisper Transcription"]) + for res in final_results: results_table.add_data(res["sample_id"], res["gt_text"], res["whisper_text"]) + wandb.log({"Text-to-Speech Whisper Transcriptions": results_table}) + + wer, errors, words = calculate_WER(recognized_text_list, groundtruth_text_list) + logger.info(f"Final TTS WER (via Whisper): {wer:.4f} | Word Errors: {int(errors)} | Total Words: {int(words)}") + wandb.log({"TTS WER (via Whisper)": wer, "Total Word Errors": errors, "Total Words": words}) + else: + logger.warning("No results were generated to calculate WER.") + + # ### CRITICAL FIX: DDP Cleanup MUST be called by all processes ### + cleanup_distributed() + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/MMaDA/inference_v2t.py b/MMaDA/inference_v2t.py new file mode 100644 index 0000000000000000000000000000000000000000..21bc96f71d11e89b229035a358e53c1d39c6d59e --- /dev/null +++ b/MMaDA/inference_v2t.py @@ -0,0 +1,191 @@ +# coding=utf-8 +# Copyright 2025 AIDAS Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +os.environ["TOKENIZERS_PARALLELISM"] = "true" +from PIL import Image +from tqdm import tqdm +import numpy as np +import torch +import wandb +import cv2 +from models import MAGVITv2, MMadaConfig, MMadaModelLM +from training.prompting_utils import UniversalPrompting +from training.utils import get_config, flatten_omega_conf, image_transform +from transformers import AutoTokenizer, AutoConfig + +def resize_vocab(model, config): + print(f"Resizing token embeddings to {config.new_vocab_size}") + model.resize_token_embeddings(config.new_vocab_size) + +def get_vq_model_class(model_type): + if model_type == "magvitv2": + return MAGVITv2 + else: + raise ValueError(f"model_type {model_type} not supported.") + +def inference_video(): + + pass + +def load_video( + video_path, + config, + uni_prompting, + vq_model=None, + device='cuda', + sample='uniform', + num_frames=8 + ): + """ + args: + video_path: path to the video file + return: video frames as a list of images + """ + cap = cv2.VideoCapture(video_path) + if not cap.isOpened(): + raise IOError(f"Could not open video file {video_path}") + + frames = [] + while True: + ret, frame = cap.read() + if not ret: + break + # Convert BGR to RGB + frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + frames.append(Image.fromarray(frame)) + + cap.release() + + total_frames = len(frames) + + if total_frames < num_frames: + raise ValueError(f"Video {video_path} has less than 8 frames, got {total_frames} frames.") + + if sample == 'uniform': + indices = np.linspace(0, total_frames - 1, num_frames).astype(int) + elif sample == 'random': + raise NotImplementedError("Random sampling not implemented yet.") + else: + raise ValueError(f"Sampling method {sample} not supported.") + + sampled_frames = [] + sampled_frames_tokens = [] + for idx in indices: + frame = frames[idx] + frame = image_transform(frame, resolution=config.dataset.params.resolution).to(device) + sampled_frames.append(frame.unsqueeze(0)) + sampled_frames_tokens.append( + vq_model.get_code(frame.unsqueeze(0)) + len(uni_prompting.text_tokenizer) + ) + + # num_frames * [num_frames, seq_len] -> [1, num_frames * seq_len] + video_tokens = torch.cat(sampled_frames_tokens, dim=1) + + return sampled_frames, video_tokens + + + +def main(): + config = get_config() + resume_wandb_run = config.wandb.resume + run_id = config.wandb.get("run_id", None) + if run_id is None: + resume_wandb_run = False + run_id = wandb.util.generate_id() + config.wandb.run_id = run_id + + wandb_config = {k: v for k, v in flatten_omega_conf(config, resolve=True)} + + wandb.init( + project="demo", + name=config.experiment.name + '_video', + config=wandb_config, + ) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + tokenizer = AutoTokenizer.from_pretrained(config.model.mmada.pretrained_model_path, padding_side="left") + + uni_prompting = UniversalPrompting(tokenizer, max_text_len=config.dataset.preprocessing.max_seq_length, + special_tokens=("<|soi|>", "<|eoi|>", "<|sov|>", "<|eov|>", "<|t2i|>", "<|mmu|>", "<|t2v|>", "<|v2v|>", "<|lvg|>", "<|v2t|>"), + ignore_id=-100, cond_dropout_prob=config.training.cond_dropout_prob, use_reserved_token=True) + + vq_model = get_vq_model_class(config.model.vq_model.type) + vq_model = vq_model.from_pretrained(config.model.vq_model.vq_model_name).to(device) + vq_model.requires_grad_(False) + vq_model.eval() + + train_step = config.step + trained_checkpoint_path = f"/home/work/AIDAS/ckpts/omada/omada-training-stage1/checkpoint-{train_step}/unwrapped_model" + + model = MMadaModelLM.from_pretrained(trained_checkpoint_path, trust_remote_code=True, torch_dtype=torch.bfloat16, config="/home/work/AIDAS/ckpts/omada/omada-training-stage1/config.json") + # model = MMadaModelLM.from_pretrained("Gen-Verse/MMaDA-8B-MixCoT", trust_remote_code=True, torch_dtype=torch.bfloat16) + model.to(device) + + mask_token_id = model.config.mask_token_id + + temperature = 0.8 # 1.0 = no change, < 1.0 = less random, > 1.0 = more random, in predictions + top_k = 1 # retain only the top_k most likely tokens, clamp others to have 0 probability + file_list = os.listdir(config.video_image_root) + file_list = [f for f in file_list if f.lower().endswith(('.mp4'))] + responses = ['' for i in range(len(file_list))] + videos = [] + config.question = config.question.split(' *** ') + for i, file_name in enumerate(tqdm(file_list)): + video_path = os.path.join(config.video_image_root, file_name) + print("current video path:", video_path) + video_frames, video_tokens = load_video( + video_path, + config, + uni_prompting, + vq_model=vq_model, + device=device, + sample='uniform', + num_frames=8 + ) + print("video tokens shape:", video_tokens.shape) + batch_size = 1 + + for question in config.question: + input_ids = uni_prompting.text_tokenizer(['<|start_header_id|>user<|end_header_id|>\n' + question +'<|start_header_id|>assistant<|end_header_id|>\n'])['input_ids'] + input_ids = torch.tensor(input_ids).to(device) + + input_ids = torch.cat([ + (torch.ones(input_ids.shape[0], 1) * uni_prompting.sptids_dict['<|v2t|>']).to(device), + (torch.ones(input_ids.shape[0], 1) * uni_prompting.sptids_dict['<|soi|>']).to(device), + video_tokens, + (torch.ones(input_ids.shape[0], 1) * uni_prompting.sptids_dict['<|eoi|>']).to(device), + (torch.ones(input_ids.shape[0], 1) * uni_prompting.sptids_dict['<|sot|>']).to(device), + input_ids + ], dim=1).long() + print(f"input_ids shape: {input_ids.shape}") + + output_ids = model.mmu_generate(input_ids, max_new_tokens=128, steps=128, block_length=128) + text = uni_prompting.text_tokenizer.batch_decode(output_ids[:, input_ids.shape[1]:], skip_special_tokens=True) + print(text) + responses[i] += f'User: ' + question + f'\n Answer : ' + text[0] + '\n' + + # images = torch.cat(images, dim=0) + # images = torch.clamp((images + 1.0) / 2.0, min=0.0, max=1.0) + # images *= 255.0 + # images = images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8) + # pil_images = [Image.fromarray(image) for image in images] + + # wandb_images = [wandb.Image(image, caption=responses[i]) for i, image in enumerate(pil_images)] + # wandb.log({"multimodal understanding": wandb_images}, step=0) + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/MMaDA/inference_v2t_temp.py b/MMaDA/inference_v2t_temp.py new file mode 100644 index 0000000000000000000000000000000000000000..1efacbdece14f25bafc0f712d4677d2d8be52908 --- /dev/null +++ b/MMaDA/inference_v2t_temp.py @@ -0,0 +1,195 @@ +# coding=utf-8 +# Copyright 2025 AIDAS Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +os.environ["TOKENIZERS_PARALLELISM"] = "true" +from PIL import Image +from tqdm import tqdm +import numpy as np +import torch +import wandb +import cv2 +from models import MAGVITv2, MMadaConfig, MMadaModelLM +from training.prompting_utils import UniversalPrompting +from training.utils import get_config, flatten_omega_conf, image_transform +from transformers import AutoTokenizer, AutoConfig + +def resize_vocab(model, config): + print(f"Resizing token embeddings to {config.new_vocab_size}") + model.resize_token_embeddings(config.new_vocab_size) + +def get_vq_model_class(model_type): + if model_type == "magvitv2": + return MAGVITv2 + else: + raise ValueError(f"model_type {model_type} not supported.") + +def inference_video(): + + pass + +def load_video( + video_path, + config, + uni_prompting, + vq_model=None, + device='cuda', + sample='uniform', + num_frames=4 + ): + """ + args: + video_path: path to the video file + return: video frames as a list of images + """ + cap = cv2.VideoCapture(video_path) + if not cap.isOpened(): + raise IOError(f"Could not open video file {video_path}") + + frames = [] + while True: + ret, frame = cap.read() + if not ret: + break + # Convert BGR to RGB + frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + frames.append(Image.fromarray(frame)) + + cap.release() + + total_frames = len(frames) + + if total_frames < num_frames: + raise ValueError(f"Video {video_path} has less than 8 frames, got {total_frames} frames.") + + if sample == 'uniform': + indices = np.linspace(0, total_frames - 1, num_frames).astype(int) + elif sample == 'random': + raise NotImplementedError("Random sampling not implemented yet.") + else: + raise ValueError(f"Sampling method {sample} not supported.") + + sampled_frames = [] + sampled_frames_tokens = [] + for idx in indices: + frame = frames[idx] + frame = image_transform(frame, resolution=config.dataset.params.resolution).to(device) + sampled_frames.append(frame.unsqueeze(0)) + sampled_frames_tokens.append( + torch.cat([ + (torch.ones(1, 1) * uni_prompting.sptids_dict['<|soi|>']).to(device), + vq_model.get_code(frame.unsqueeze(0)) + len(uni_prompting.text_tokenizer), + (torch.ones(1, 1) * uni_prompting.sptids_dict['<|eoi|>']).to(device) + ], dim=1) # dim=1ģ“ė©“ token axis 기준 concat + ) + + # num_frames * [num_frames, seq_len] -> [1, num_frames * seq_len] + video_tokens = torch.cat(sampled_frames_tokens, dim=1) + + return sampled_frames, video_tokens + + + +def main(): + config = get_config() + resume_wandb_run = config.wandb.resume + run_id = config.wandb.get("run_id", None) + if run_id is None: + resume_wandb_run = False + run_id = wandb.util.generate_id() + config.wandb.run_id = run_id + + wandb_config = {k: v for k, v in flatten_omega_conf(config, resolve=True)} + + wandb.init( + project="demo", + name=config.experiment.name + '_video', + config=wandb_config, + ) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + tokenizer = AutoTokenizer.from_pretrained(config.model.mmada.pretrained_model_path, padding_side="left") + + uni_prompting = UniversalPrompting(tokenizer, max_text_len=config.dataset.preprocessing.max_seq_length, + special_tokens=("<|soi|>", "<|eoi|>", "<|sov|>", "<|eov|>", "<|t2i|>", "<|mmu|>", "<|t2v|>", "<|v2v|>", "<|lvg|>"), + ignore_id=-100, cond_dropout_prob=config.training.cond_dropout_prob, use_reserved_token=True) + + vq_model = get_vq_model_class(config.model.vq_model.type) + vq_model = vq_model.from_pretrained(config.model.vq_model.vq_model_name).to(device) + vq_model.requires_grad_(False) + vq_model.eval() + + model = MMadaModelLM.from_pretrained(config.model.mmada.pretrained_model_path, trust_remote_code=True, torch_dtype=torch.bfloat16) + model.to(device) + + mask_token_id = model.config.mask_token_id + + temperature = 0.8 # 1.0 = no change, < 1.0 = less random, > 1.0 = more random, in predictions + top_k = 1 # retain only the top_k most likely tokens, clamp others to have 0 probability + file_list = os.listdir(config.video_image_root) + file_list = [f for f in file_list if f.lower().endswith(('.mp4'))] + responses = ['' for i in range(len(file_list))] + videos = [] + config.question = config.question.split(' *** ') + for i, file_name in enumerate(tqdm(file_list)): + video_path = os.path.join(config.video_image_root, file_name) + print("current video path:", video_path) + video_frames, video_tokens = load_video( + video_path, + config, + uni_prompting, + vq_model=vq_model, + device=device, + sample='uniform', + num_frames=8 + ) + print("video tokens shape:", video_tokens.shape) + batch_size = 1 + + # print(video_tokens) + + for question in config.question: + input_ids = uni_prompting.text_tokenizer(['<|start_header_id|>user<|end_header_id|>\n' + question +'<|start_header_id|>assistant<|end_header_id|>\n'])['input_ids'] + input_ids = torch.tensor(input_ids).to(device) + + input_ids = torch.cat([ + (torch.ones(input_ids.shape[0], 1) * uni_prompting.sptids_dict['<|mmu|>']).to(device), + # (torch.ones(input_ids.shape[0], 1) * uni_prompting.sptids_dict['<|soi|>']).to(device), + video_tokens, + # (torch.ones(input_ids.shape[0], 1) * uni_prompting.sptids_dict['<|eoi|>']).to(device), + (torch.ones(input_ids.shape[0], 1) * uni_prompting.sptids_dict['<|sot|>']).to(device), + input_ids + ], dim=1).long() + + # print(f"input_ids shape: {input_ids.shape}") + # print(f"input_ids: {input_ids}") + + output_ids = model.mmu_generate(input_ids, max_new_tokens=1024, steps=512, block_length=64) + text = uni_prompting.text_tokenizer.batch_decode(output_ids[:, input_ids.shape[1]:], skip_special_tokens=True) + print(text) + responses[i] += f'User: ' + question + f'\n Answer : ' + text[0] + '\n' + + # images = torch.cat(images, dim=0) + # images = torch.clamp((images + 1.0) / 2.0, min=0.0, max=1.0) + # images *= 255.0 + # images = images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8) + # pil_images = [Image.fromarray(image) for image in images] + + # wandb_images = [wandb.Image(image, caption=responses[i]) for i, image in enumerate(pil_images)] + # wandb.log({"multimodal understanding": wandb_images}, step=0) + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/MMaDA/lm_chat_validation/description.txt b/MMaDA/lm_chat_validation/description.txt new file mode 100644 index 0000000000000000000000000000000000000000..6a6136bac694da374c9f1fb427aaf3c451c2a0bc --- /dev/null +++ b/MMaDA/lm_chat_validation/description.txt @@ -0,0 +1,5 @@ +<|start_header_id|>user<|end_header_id|> +From the following items, select the one that belongs to animals: +1. Apple +2. Sun +3. Dog<|start_header_id|>assistant<|end_header_id|> diff --git a/MMaDA/lm_chat_validation/questions.jsonl b/MMaDA/lm_chat_validation/questions.jsonl new file mode 100644 index 0000000000000000000000000000000000000000..676484aced54b32426cecfa29d1d77e81125ca1d --- /dev/null +++ b/MMaDA/lm_chat_validation/questions.jsonl @@ -0,0 +1,11 @@ +{"question":"Write a short poem with the theme of the sea."} +{"question":"From the following items, select the one that belongs to animals:\n1. Apple\n2. Sun\n3. Dog"} +{"question":"Please answer the following question based on the context provided.\nContext: \nGood Friday is a Christian holiday commemorating the crucifixion of Jesus and his death at Calvary. It is observed during Holy Week as part of the Paschal Triduum. It is also known as Holy Friday, Great Friday, Great and Holy Friday (also Holy and Great Friday), and Black Friday.\nQuestion: \nExtract the various ways to say Good Friday from the text. Separate them with a new line."} +{"question":"Write a speech introducing yourself to the audience."} +{"question":"Please answer the following question based on the context provided.\nContext:\nThe Maurice \"Rocket\" Richard Trophy, also known as the Rocket Richard Trophy, is awarded annually to the leading goal scorer in the National Hockey League (NHL). It was donated to the NHL by the Montreal Canadiens in 1998–99 and is named in honour of legendary Montreal Canadiens right winger Maurice \"Rocket\" Richard. First won by Teemu Selanne, it is currently held by Auston Matthews, who scored 60 goals during the 2021–22 NHL season.\nQuestion:\nWhat is the Maurice Richard Trophy"} +{"question":"Explain what an embedding layer is and its purpose in Machine Learning."} +{"question":"You should first think about the reasoning process in the mind and then provide the user with the answer. The reasoning process is enclosed within tags, i.e. reasoning process here answer here\nA rectangular prism has a length of 5 units, a width of 4 units, and a height of 3 units. What is the volume of the prism?\n"} +{"question":"You should first think about the reasoning process in the mind and then provide the user with the answer. The reasoning process is enclosed within tags, i.e. reasoning process here answer here\nEvaluate $ (1 + i)^4 $.\n"} +{"question":"You should first think about the reasoning process in the mind and then provide the user with the answer. The reasoning process is enclosed within tags, i.e. reasoning process here answer here\nGiven $\\tan\\beta= \\frac {1}{2}$, find the value of $\\sin^2\\beta-3\\sin\\beta\\cos\\beta+4\\cos^2\\beta$.\n"} +{"question":"You should first think about the reasoning process in the mind and then provide the user with the answer. The reasoning process is enclosed within tags, i.e. reasoning process here answer here\nJames has 7 apples. 4 of them are red, and 3 of them are green. If he chooses 2 apples at random, what is the probability that both the apples he chooses are green?\n"} +{"question":"You should first think about the reasoning process in the mind and then provide the user with the answer. The reasoning process is enclosed within tags, i.e. reasoning process here answer here\nThe user will describe something indirectly, and you need to infer and answer what that thing is (without any explanation). If there are multiple possible answers, choose one of them.\nThe thing: A staple food in many Asian countries\n"} \ No newline at end of file diff --git a/MMaDA/logs/v2t_None.log b/MMaDA/logs/v2t_None.log new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/MMaDA/logs/v2t_c8fnnhj0.log b/MMaDA/logs/v2t_c8fnnhj0.log new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/MMaDA/merging.py b/MMaDA/merging.py new file mode 100644 index 0000000000000000000000000000000000000000..03700838127845b65560aa684203556f8ce70f69 --- /dev/null +++ b/MMaDA/merging.py @@ -0,0 +1,273 @@ +import argparse +import torch +from transformers import AutoTokenizer +from safetensors.torch import load_file +import os +from typing import Dict, Any, Type, Set, Optional + +# Modify this to match the path of your 'models.py' +from models import OMadaModelLM, MMadaModelLM + +def merge_models_safetensors( + hf_model_name: str, + local_model_directory: str, + output_dir: str, + alpha: float = 0.5, + merge_vocab: bool = True, + hf_for_common: bool = False, + task_vector_scale: Optional[float] = None, +): + """ + Merges a Hugging Face model with a local sharded safetensors checkpoint. + + Args: + hf_model_name (str): The name or path of the base model from Hugging Face. + local_model_directory (str): The directory path containing the local sharded safetensors files. + output_dir (str): The directory to save the merged model and tokenizer. + alpha (float): The weighting factor for the Hugging Face model. + merge_vocab (bool): If True, merges common token embeddings. + hf_for_common (bool): If True, uses HF model weights for common token embeddings. + """ + print("--- Starting model merge process ---") + print(f"Base Hugging Face Model: {hf_model_name}") + print(f"Local Model Directory (Safetensors shards): {local_model_directory}") + print(f"Merge Alpha: {alpha}") + print(f"Merge Vocabulary (for common tokens): {merge_vocab}") + print(f"Use HF for common tokens: {hf_for_common}") + if task_vector_scale is not None: + print(f"Task vector scale (overlapping parameters): {task_vector_scale}") + + torch_dtype = { + "float16": torch.float16, + "bfloat16": torch.bfloat16, + "float32": torch.float32, + "auto": torch.float16, + }.get(str("bfloat16").lower(), torch.bfloat16) + + model_class_hf: Type[torch.nn.Module] = MMadaModelLM + model_class_local: Type[torch.nn.Module] = OMadaModelLM + + try: + hf_model = model_class_hf.from_pretrained( + hf_model_name, + trust_remote_code=True, + torch_dtype=torch_dtype, + ).eval() + hf_state_dict = hf_model.state_dict() + print(f"Successfully loaded Hugging Face model '{hf_model_name}'.") + except Exception as e: + print(f"Error loading Hugging Face model '{hf_model_name}': {e}") + return + + try: + local_model = model_class_local.from_pretrained( + local_model_directory, + trust_remote_code=True, + torch_dtype=torch_dtype, + config="/home/work/AIDAS/ckpts/omada/omada-training-stage1/config.json" + ).eval() + local_state_dict = local_model.state_dict() + print(f"Successfully loaded local model from directory: '{local_model_directory}'.") + except Exception as e: + print(f"Error loading local model from directory '{local_model_directory}': {e}") + print("Please ensure the directory contains all safetensors shards and a valid config.json.") + return + + merged_state_dict = {} + all_keys = set(hf_state_dict.keys()).union(set(local_state_dict.keys())) + + print("Merging model weights...") + for key in all_keys: + if key in ["model.transformer.wte.weight", "model.transformer.ff_out.weight"]: + # --- Merge logic for embedding and final projection layers --- + + # 1. Use HF for common tokens, local for expanded portion (--hf-for-common) + if hf_for_common: + hf_weights = hf_state_dict[key] + local_weights = local_state_dict[key] + + if hf_weights.shape[0] < local_weights.shape[0]: + print(f"Detected embedding size mismatch for key '{key}'. Merging by tensor indices.") + merged_weights = torch.zeros_like(local_weights) + common_size = hf_weights.shape[0] + + # Copy HF weights for the common vocabulary portion (no interpolation) + merged_weights[:common_size, :] = hf_weights + + # Copy local weights for the expanded embedding space + merged_weights[common_size:, :] = local_weights[common_size:, :] + + merged_state_dict[key] = merged_weights + print(f"Successfully merged {common_size} old and {local_weights.shape[0] - common_size} new embeddings.") + continue + else: + print(f"Embedding sizes are identical for key '{key}'. Proceeding with standard vocab merge logic.") + + # 2. Use local model weights for all tokens (--no-merge-vocab) + elif not merge_vocab: + print(f"Vocab merge disabled. Using local model weights for key '{key}'.") + merged_state_dict[key] = local_state_dict[key] + continue + + # 3. Average common tokens, use local for new tokens (default) + print(f"Merging common tokens for key '{key}'.") + + try: + hf_tokenizer = AutoTokenizer.from_pretrained(hf_model_name, trust_remote_code=True) + local_tokenizer = AutoTokenizer.from_pretrained(local_model_directory, trust_remote_code=True) + except Exception as e: + print(f"Error loading tokenizers for special merge handling: {e}") + if key in hf_state_dict and key in local_state_dict: + if hf_state_dict[key].shape == local_state_dict[key].shape: + merged_weights = (alpha * hf_state_dict[key]) + ((1 - alpha) * local_state_dict[key]) + merged_state_dict[key] = merged_weights + else: + merged_state_dict[key] = local_state_dict[key] + continue + + hf_vocab_ids: Set[int] = set(hf_tokenizer.get_vocab().values()) + local_vocab_ids: Set[int] = set(local_tokenizer.get_vocab().values()) + + common_vocab_ids = hf_vocab_ids.intersection(local_vocab_ids) + new_vocab_ids = local_vocab_ids.difference(hf_vocab_ids) + + print(f"Found {len(common_vocab_ids)} common tokens and {len(new_vocab_ids)} new tokens.") + + hf_weights = hf_state_dict[key] + local_weights = local_state_dict[key] + + merged_weights = local_weights.clone() + + common_indices = torch.tensor(sorted(common_vocab_ids), dtype=torch.long) + if task_vector_scale is not None: + merged_weights[common_indices] = hf_weights[common_indices] + task_vector_scale * (local_weights[common_indices] - hf_weights[common_indices]) + else: + merged_weights[common_indices] = ( + alpha * hf_weights[common_indices] + + (1 - alpha) * local_weights[common_indices] + ) + + new_indices = torch.tensor(sorted(new_vocab_ids), dtype=torch.long) + merged_weights[new_indices] = local_weights[new_indices] + + merged_state_dict[key] = merged_weights + + elif key in hf_state_dict and key in local_state_dict: + if hf_state_dict[key].shape == local_state_dict[key].shape: + if task_vector_scale is not None: + merged_weights = hf_state_dict[key] + task_vector_scale * (local_state_dict[key] - hf_state_dict[key]) + else: + merged_weights = (alpha * hf_state_dict[key]) + ((1 - alpha) * local_state_dict[key]) + merged_state_dict[key] = merged_weights + else: + hf_shape = hf_state_dict[key].shape + local_shape = local_state_dict[key].shape + if len(hf_shape) > 0 and hf_shape[0] <= local_shape[0] and hf_shape[1:] == local_shape[1:]: + print(f"Key '{key}' has expanded leading dimension. Merging common portion and keeping local extras.") + merged_weights = local_state_dict[key].clone() + common_size = hf_shape[0] + merged_weights[:common_size] = hf_state_dict[key][:common_size] + merged_state_dict[key] = merged_weights + else: + print(f"Warning: Key '{key}' has mismatched shapes (hf={hf_shape}, local={local_shape}). Using local model's weights.") + merged_state_dict[key] = local_state_dict[key] + elif key in hf_state_dict: + merged_state_dict[key] = hf_state_dict[key] + elif key in local_state_dict: + merged_state_dict[key] = local_state_dict[key] + + print(f"Successfully processed {len(all_keys)} parameters for merging.") + + print("Loading merged weights into a new model instance...") + try: + new_model = model_class_local.from_pretrained( + local_model_directory, + trust_remote_code=True, + torch_dtype=torch_dtype, + config="/home/work/AIDAS/ckpts/omada/omada-training-stage1/config.json" + ).eval() + + new_model.load_state_dict(merged_state_dict, strict=False) + print("Merged weights successfully loaded into new model.") + except Exception as e: + print(f"Error loading merged state dict into new model: {e}") + return + + try: + os.makedirs(output_dir, exist_ok=True) + new_model.save_pretrained(output_dir, safe_serialization=True) + + local_tokenizer = AutoTokenizer.from_pretrained(local_model_directory, trust_remote_code=True) + local_tokenizer.save_pretrained(output_dir) + + print(f"--- Model successfully merged and saved to '{output_dir}' ---") + except Exception as e: + print(f"Error saving the merged model and tokenizer: {e}") + +# --- Example Usage --- +if __name__ == '__main__': + parser = argparse.ArgumentParser(description="Merge a Hugging Face model with a local checkpoint.") + parser.add_argument( + "--alpha", + type=float, + default=999, + help="The weighting factor for the Hugging Face model. 0.5 for 50/50 average." + ) + # Existing '--no-merge-vocab' flag + parser.add_argument( + "--no-merge-vocab", + action="store_false", + dest="merge_vocab", + help="Do not merge overlapping token embeddings; use local model's embeddings as-is." + ) + # New '--hf-for-common' flag added + parser.add_argument( + "--hf-for-common", + action="store_true", + help="Use HF model weights for common tokens and local weights for expanded space." + ) + parser.add_argument( + "--task-vector-scale", + type=float, + default=None, + help="Scale factor for the task vector (local - HF) over overlapping parameters.", + ) + + args = parser.parse_args() + + MERGE_ALPHA = args.alpha + MERGE_VOCAB = args.merge_vocab + HF_FOR_COMMON = args.hf_for_common + TASK_VECTOR_SCALE = args.task_vector_scale + + if not MERGE_VOCAB and HF_FOR_COMMON: + print("Error: You cannot use --no-merge-vocab and --hf-for-common at the same time.") + exit(1) + + HF_MODEL_NAME = "Gen-Verse/MMaDA-8B-MixCoT" + LOCAL_MODEL_DIRECTORY = "/home/work/AIDAS/ckpts/omada/omada-training-stage1_7th/checkpoint-315000/unwrapped_model/" + + # Dynamic output directory based on flags + if HF_FOR_COMMON: + SUB_DIR = "hf_common_merge" + elif not MERGE_VOCAB: + SUB_DIR = "no_vocab_merge" + else: + SUB_DIR = "average_merge" + + scale_suffix = "" + if TASK_VECTOR_SCALE is not None: + scale_str = str(TASK_VECTOR_SCALE).replace(".", "p") + scale_suffix = f"_scale_{scale_str}" + + OUTPUT_DIRECTORY = f"/home/work/AIDAS/ckpts/merged_model/{SUB_DIR}_alpha_{MERGE_ALPHA}{scale_suffix}" + + merge_models_safetensors( + hf_model_name=HF_MODEL_NAME, + local_model_directory=LOCAL_MODEL_DIRECTORY, + output_dir=OUTPUT_DIRECTORY, + alpha=MERGE_ALPHA, + merge_vocab=MERGE_VOCAB, + hf_for_common=HF_FOR_COMMON, + task_vector_scale=TASK_VECTOR_SCALE, + ) diff --git a/MMaDA/merging.sh b/MMaDA/merging.sh new file mode 100644 index 0000000000000000000000000000000000000000..00d1df07a7fc289413977e083ea3549d7b95c1ee --- /dev/null +++ b/MMaDA/merging.sh @@ -0,0 +1,64 @@ +#!/bin/bash + +# Path to the Python merging script +PYTHON_SCRIPT="/home/work/AIDAS/MMaDA/merging.py" + +# Define the range of alpha values for grid search +ALPHA_VALUES=() + +# Optional task-vector scaling factors (set to an empty array to disable) +TASK_VECTOR_SCALES=(0.5 0.6 0.7 0.8 0.9) + +MERGE_STRATEGIES=("hf-for-common") + +# Ensure there is at least one placeholder iteration even if alpha list is empty +if [ ${#ALPHA_VALUES[@]} -eq 0 ]; then + ALPHA_VALUES=("default") +fi + +echo "Starting grid search for model merging..." +echo "Alpha values: ${ALPHA_VALUES[@]}" +echo "Task vector scales: ${TASK_VECTOR_SCALES[@]}" +echo "Merge strategies: ${MERGE_STRATEGIES[@]}" +echo "----------------------------------------" + +for alpha in "${ALPHA_VALUES[@]}"; do + for task_scale in "${TASK_VECTOR_SCALES[@]}"; do + for strategy in "${MERGE_STRATEGIES[@]}"; do + + task_arg="" + if [ -n "${task_scale}" ]; then + task_arg="--task-vector-scale ${task_scale}" + fi + + alpha_arg="" + alpha_label="${alpha}" + if [ "${alpha}" != "default" ] && [ -n "${alpha}" ]; then + alpha_arg="--alpha ${alpha}" + else + alpha_label="(script default)" + fi + + # Construct the command based on the current strategy + if [ "${strategy}" == "default" ]; then + command="python ${PYTHON_SCRIPT} ${alpha_arg} ${task_arg}" + description="Strategy: Default (Average Merge)" + elif [ "${strategy}" == "no-merge-vocab" ]; then + command="python ${PYTHON_SCRIPT} ${alpha_arg} --no-merge-vocab ${task_arg}" + description="Strategy: No Vocab Merge" + elif [ "${strategy}" == "hf-for-common" ]; then + command="python ${PYTHON_SCRIPT} ${alpha_arg} --hf-for-common ${task_arg}" + description="Strategy: HF For Common Tokens" + fi + + echo "Executing command with alpha=${alpha_label}, task_scale=${task_scale}, ${description}" + echo "Command: ${command}" + + ${command} + + echo "----------------------------------------" + done + done +done + +echo "All grid search runs completed." diff --git a/MMaDA/mmada-training-stage3-llada-instruct-v2t-special-token-1e-5/config.yaml b/MMaDA/mmada-training-stage3-llada-instruct-v2t-special-token-1e-5/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..84a1620757eab89cf00649b934085f20d41ce20d --- /dev/null +++ b/MMaDA/mmada-training-stage3-llada-instruct-v2t-special-token-1e-5/config.yaml @@ -0,0 +1,78 @@ +wandb: + entity: null + resume: auto + run_id: 39gv7phb +experiment: + project: mmada-training-v2t + name: mmada-training-stage3-llada-instruct-v2t + output_dir: mmada-training-stage3-llada-instruct-v2t-special-token-1e-5 + max_train_examples_t2i: 40000000 + max_train_examples_mmu: 40000000 ddd# + save_every: 1000 + eval_every: 2500 + generate_every: 1000 + log_every: 10 + log_grad_norm_every: 100 + resume_from_checkpoint: latest + val_every: 50 + max_val_examples_t2i: 2000 + logging_dir: mmada-training-stage3-llada-instruct-v2t-special-token-1e-5/logs +model: + vq_model: + type: magvitv2 + vq_model_name: showlab/magvitv2 + mmada: + tokenizer_path: GSAI-ML/LLaDA-8B-Instruct + pretrained_model_path: Gen-Verse/MMaDA-8B-Base + w_clip_vit: false + new_vocab_size: 134656 + llm_vocab_size: 126464 + codebook_size: 8192 + num_vq_tokens: 256 + num_new_special_tokens: 0 + tie_word_embeddings: false + gradient_checkpointing: true +dataset: + und_type: captioning + combined_loader_mode: max_size_cycle + preprocessing: + max_seq_length: 128 + resolution: 128 + center_crop: false + random_flip: false + params: + num_workers: 32 +optimizer: + name: adamw + params: + learning_rate: 1.0e-05 + scale_lr: false + beta1: 0.9 + beta2: 0.999 + weight_decay: 0.01 + epsilon: 1.0e-08 +lr_scheduler: + scheduler: cosine + params: + learning_rate: ${optimizer.params.learning_rate} + warmup_steps: 5000 + min_lr_scale: 0.1 +training: + gradient_accumulation_steps: 4 + noise_type: mask + batch_size_v2t: 4 + batch_size_mmu: 1 + mixed_precision: bf16 + enable_tf32: true + seed: 10086 + max_train_steps: 1000000 + overfit_one_batch: false + cond_dropout_prob: 0.1 + min_masking_rate: 0.0 + label_smoothing: 0.0 + max_grad_norm: 1 + guidance_scale: 3 + generation_timesteps: 12 + mmu_coeff: 1.0 + validation_seed: 42 +config: /home/work/AIDAS/MMaDA/configs/mmada_pretraining_v2t.yaml diff --git a/MMaDA/models/__init__.py b/MMaDA/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f7ca0cec11b309433b26291ba4bb3f99ef98d4b1 --- /dev/null +++ b/MMaDA/models/__init__.py @@ -0,0 +1,4 @@ +from .modeling_magvitv2 import VQGANEncoder, VQGANDecoder, LFQuantizer, MAGVITv2 +from .sampling import * +from .modeling_mmada import MMadaModelLM, MMadaConfig +from .modeling_omada import OMadaModelLM, OMadaConfig \ No newline at end of file diff --git a/MMaDA/models/common_modules.py b/MMaDA/models/common_modules.py new file mode 100644 index 0000000000000000000000000000000000000000..ab7017ead91e759a896ab854c8a976038424bd4d --- /dev/null +++ b/MMaDA/models/common_modules.py @@ -0,0 +1,357 @@ +""" +Modified from https://github.com/CompVis/taming-transformers/blob/master/taming/modules/diffusionmodules/model.py#L34 +""" + +import math +from typing import Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + + +def nonlinearity(x): + # swish + return x * torch.sigmoid(x) + + +def Normalize(in_channels): + return torch.nn.GroupNorm( + num_groups=32, num_channels=in_channels, eps=1e-6, affine=True + ) + + +class Upsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + return x + + +class DepthToSpaceUpsample(nn.Module): + def __init__( + self, + in_channels, + ): + super().__init__() + conv = nn.Conv2d(in_channels, in_channels * 4, 1) + + self.net = nn.Sequential( + conv, + nn.SiLU(), + Rearrange("b (c p1 p2) h w -> b c (h p1) (w p2)", p1=2, p2=2), + ) + + self.init_conv_(conv) + + def init_conv_(self, conv): + o, i, h, w = conv.weight.shape + conv_weight = torch.empty(o // 4, i, h, w) + nn.init.kaiming_uniform_(conv_weight) + conv_weight = repeat(conv_weight, "o ... -> (o 4) ...") + + conv.weight.data.copy_(conv_weight) + nn.init.zeros_(conv.bias.data) + + def forward(self, x): + out = self.net(x) + return out + + +class Downsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=2, padding=0 + ) + + def forward(self, x): + if self.with_conv: + pad = (0, 1, 0, 1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + return x + + +def unpack_time(t, batch): + _, c, w, h = t.size() + out = torch.reshape(t, [batch, -1, c, w, h]) + out = rearrange(out, "b t c h w -> b c t h w") + return out + + +def pack_time(t): + out = rearrange(t, "b c t h w -> b t c h w") + _, _, c, w, h = out.size() + return torch.reshape(out, [-1, c, w, h]) + + +class TimeDownsample2x(nn.Module): + def __init__( + self, + dim, + dim_out=None, + kernel_size=3, + ): + super().__init__() + if dim_out is None: + dim_out = dim + self.time_causal_padding = (kernel_size - 1, 0) + self.conv = nn.Conv1d(dim, dim_out, kernel_size, stride=2) + + def forward(self, x): + x = rearrange(x, "b c t h w -> b h w c t") + b, h, w, c, t = x.size() + x = torch.reshape(x, [-1, c, t]) + + x = F.pad(x, self.time_causal_padding) + out = self.conv(x) + + out = torch.reshape(out, [b, h, w, c, t]) + out = rearrange(out, "b h w c t -> b c t h w") + out = rearrange(out, "b h w c t -> b c t h w") + return out + + +class TimeUpsample2x(nn.Module): + def __init__(self, dim, dim_out=None): + super().__init__() + if dim_out is None: + dim_out = dim + conv = nn.Conv1d(dim, dim_out * 2, 1) + + self.net = nn.Sequential( + nn.SiLU(), conv, Rearrange("b (c p) t -> b c (t p)", p=2) + ) + + self.init_conv_(conv) + + def init_conv_(self, conv): + o, i, t = conv.weight.shape + conv_weight = torch.empty(o // 2, i, t) + nn.init.kaiming_uniform_(conv_weight) + conv_weight = repeat(conv_weight, "o ... -> (o 2) ...") + + conv.weight.data.copy_(conv_weight) + nn.init.zeros_(conv.bias.data) + + def forward(self, x): + x = rearrange(x, "b c t h w -> b h w c t") + b, h, w, c, t = x.size() + x = torch.reshape(x, [-1, c, t]) + + out = self.net(x) + out = out[:, :, 1:].contiguous() + + out = torch.reshape(out, [b, h, w, c, t]) + out = rearrange(out, "b h w c t -> b c t h w") + return out + + +class AttnBlock(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.k = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.v = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.proj_out = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = q.reshape(b, c, h * w) + q = q.permute(0, 2, 1) # b,hw,c + k = k.reshape(b, c, h * w) # b,c,hw + w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c) ** (-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b, c, h * w) + w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b, c, h, w) + + h_ = self.proj_out(h_) + + return x + h_ + + +class TimeAttention(AttnBlock): + def forward(self, x, *args, **kwargs): + x = rearrange(x, "b c t h w -> b h w t c") + b, h, w, t, c = x.size() + x = torch.reshape(x, (-1, t, c)) + + x = super().forward(x, *args, **kwargs) + + x = torch.reshape(x, [b, h, w, t, c]) + return rearrange(x, "b h w t c -> b c t h w") + + +class Residual(nn.Module): + def __init__(self, fn: nn.Module): + super().__init__() + self.fn = fn + + def forward(self, x, **kwargs): + return self.fn(x, **kwargs) + x + + +def cast_tuple(t, length=1): + return t if isinstance(t, tuple) else ((t,) * length) + + +class CausalConv3d(nn.Module): + def __init__( + self, + chan_in, + chan_out, + kernel_size: Union[int, Tuple[int, int, int]], + pad_mode="constant", + **kwargs + ): + super().__init__() + kernel_size = cast_tuple(kernel_size, 3) + + time_kernel_size, height_kernel_size, width_kernel_size = kernel_size + + dilation = kwargs.pop("dilation", 1) + stride = kwargs.pop("stride", 1) + + self.pad_mode = pad_mode + time_pad = dilation * (time_kernel_size - 1) + (1 - stride) + height_pad = height_kernel_size // 2 + width_pad = width_kernel_size // 2 + + self.time_pad = time_pad + self.time_causal_padding = ( + width_pad, + width_pad, + height_pad, + height_pad, + time_pad, + 0, + ) + + stride = (stride, 1, 1) + dilation = (dilation, 1, 1) + self.conv = nn.Conv3d( + chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs + ) + + def forward(self, x): + pad_mode = self.pad_mode if self.time_pad < x.shape[2] else "constant" + + x = F.pad(x, self.time_causal_padding, mode=pad_mode) + return self.conv(x) + + +def ResnetBlockCausal3D( + dim, kernel_size: Union[int, Tuple[int, int, int]], pad_mode: str = "constant" +): + net = nn.Sequential( + Normalize(dim), + nn.SiLU(), + CausalConv3d(dim, dim, kernel_size, pad_mode), + Normalize(dim), + nn.SiLU(), + CausalConv3d(dim, dim, kernel_size, pad_mode), + ) + return Residual(net) + + +class ResnetBlock(nn.Module): + def __init__( + self, + *, + in_channels, + out_channels=None, + conv_shortcut=False, + dropout, + temb_channels=512 + ): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + if temb_channels > 0: + self.temb_proj = torch.nn.Linear(temb_channels, out_channels) + else: + self.temb_proj = None + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d( + out_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + else: + self.nin_shortcut = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=1, stride=1, padding=0 + ) + + def forward(self, x, temb): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x + h diff --git a/MMaDA/models/configuration_emova_speech_tokenizer.py b/MMaDA/models/configuration_emova_speech_tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..d591e1b3486b6e1b03651a863cb0060447be5fb5 --- /dev/null +++ b/MMaDA/models/configuration_emova_speech_tokenizer.py @@ -0,0 +1,111 @@ +# coding=utf-8 +# Copyright 2024 The EMOVA team and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" EMOVASpeechTokenizer model configuration """ + +import copy +from typing import List + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + +U2S_STYLES = [ + 'gender-female_emotion-angry_speed-fast_pitch-high', 'gender-female_emotion-angry_speed-fast_pitch-low', 'gender-female_emotion-angry_speed-fast_pitch-normal', + 'gender-female_emotion-angry_speed-normal_pitch-high', 'gender-female_emotion-angry_speed-normal_pitch-low', 'gender-female_emotion-angry_speed-normal_pitch-normal', + 'gender-female_emotion-angry_speed-slow_pitch-high', 'gender-female_emotion-angry_speed-slow_pitch-low', 'gender-female_emotion-angry_speed-slow_pitch-normal', + 'gender-female_emotion-disgusted_speed-fast_pitch-high', 'gender-female_emotion-disgusted_speed-fast_pitch-low', 'gender-female_emotion-disgusted_speed-fast_pitch-normal', + 'gender-female_emotion-disgusted_speed-normal_pitch-high', 'gender-female_emotion-disgusted_speed-normal_pitch-low', 'gender-female_emotion-disgusted_speed-normal_pitch-normal', + 'gender-female_emotion-disgusted_speed-slow_pitch-high', 'gender-female_emotion-disgusted_speed-slow_pitch-low', 'gender-female_emotion-disgusted_speed-slow_pitch-normal', + 'gender-female_emotion-fearful_speed-fast_pitch-high', 'gender-female_emotion-fearful_speed-fast_pitch-low', 'gender-female_emotion-fearful_speed-fast_pitch-normal', + 'gender-female_emotion-fearful_speed-normal_pitch-high', 'gender-female_emotion-fearful_speed-normal_pitch-low', 'gender-female_emotion-fearful_speed-normal_pitch-normal', + 'gender-female_emotion-fearful_speed-slow_pitch-high', 'gender-female_emotion-fearful_speed-slow_pitch-low', 'gender-female_emotion-fearful_speed-slow_pitch-normal', + 'gender-female_emotion-happy_speed-fast_pitch-high', 'gender-female_emotion-happy_speed-fast_pitch-low', 'gender-female_emotion-happy_speed-fast_pitch-normal', + 'gender-female_emotion-happy_speed-normal_pitch-high', 'gender-female_emotion-happy_speed-normal_pitch-low', 'gender-female_emotion-happy_speed-normal_pitch-normal', + 'gender-female_emotion-happy_speed-slow_pitch-high', 'gender-female_emotion-happy_speed-slow_pitch-low', 'gender-female_emotion-happy_speed-slow_pitch-normal', + 'gender-female_emotion-neutral_speed-fast_pitch-high', 'gender-female_emotion-neutral_speed-fast_pitch-low', 'gender-female_emotion-neutral_speed-fast_pitch-normal', + 'gender-female_emotion-neutral_speed-normal_pitch-high', 'gender-female_emotion-neutral_speed-normal_pitch-low', 'gender-female_emotion-neutral_speed-normal_pitch-normal', + 'gender-female_emotion-neutral_speed-slow_pitch-high', 'gender-female_emotion-neutral_speed-slow_pitch-low', 'gender-female_emotion-neutral_speed-slow_pitch-normal', + 'gender-female_emotion-sad_speed-fast_pitch-high', 'gender-female_emotion-sad_speed-fast_pitch-low', 'gender-female_emotion-sad_speed-fast_pitch-normal', + 'gender-female_emotion-sad_speed-normal_pitch-high', 'gender-female_emotion-sad_speed-normal_pitch-low', 'gender-female_emotion-sad_speed-normal_pitch-normal', + 'gender-female_emotion-sad_speed-slow_pitch-high', 'gender-female_emotion-sad_speed-slow_pitch-low', 'gender-female_emotion-sad_speed-slow_pitch-normal', + 'gender-female_emotion-surprised_speed-fast_pitch-high', 'gender-female_emotion-surprised_speed-fast_pitch-low', 'gender-female_emotion-surprised_speed-fast_pitch-normal', + 'gender-female_emotion-surprised_speed-normal_pitch-high', 'gender-female_emotion-surprised_speed-normal_pitch-low', 'gender-female_emotion-surprised_speed-normal_pitch-normal', + 'gender-female_emotion-surprised_speed-slow_pitch-high', 'gender-female_emotion-surprised_speed-slow_pitch-low', 'gender-female_emotion-surprised_speed-slow_pitch-normal', + 'gender-male_emotion-angry_speed-fast_pitch-high', 'gender-male_emotion-angry_speed-fast_pitch-low', 'gender-male_emotion-angry_speed-fast_pitch-normal', + 'gender-male_emotion-angry_speed-normal_pitch-high', 'gender-male_emotion-angry_speed-normal_pitch-low', 'gender-male_emotion-angry_speed-normal_pitch-normal', + 'gender-male_emotion-angry_speed-slow_pitch-high', 'gender-male_emotion-angry_speed-slow_pitch-low', 'gender-male_emotion-angry_speed-slow_pitch-normal', + 'gender-male_emotion-disgusted_speed-fast_pitch-high', 'gender-male_emotion-disgusted_speed-fast_pitch-low', 'gender-male_emotion-disgusted_speed-fast_pitch-normal', + 'gender-male_emotion-disgusted_speed-normal_pitch-high', 'gender-male_emotion-disgusted_speed-normal_pitch-low', 'gender-male_emotion-disgusted_speed-normal_pitch-normal', + 'gender-male_emotion-disgusted_speed-slow_pitch-high', 'gender-male_emotion-disgusted_speed-slow_pitch-low', 'gender-male_emotion-disgusted_speed-slow_pitch-normal', + 'gender-male_emotion-fearful_speed-fast_pitch-high', 'gender-male_emotion-fearful_speed-fast_pitch-low', 'gender-male_emotion-fearful_speed-fast_pitch-normal', + 'gender-male_emotion-fearful_speed-normal_pitch-high', 'gender-male_emotion-fearful_speed-normal_pitch-low', 'gender-male_emotion-fearful_speed-normal_pitch-normal', + 'gender-male_emotion-fearful_speed-slow_pitch-high', 'gender-male_emotion-fearful_speed-slow_pitch-low', 'gender-male_emotion-fearful_speed-slow_pitch-normal', + 'gender-male_emotion-happy_speed-fast_pitch-high', 'gender-male_emotion-happy_speed-fast_pitch-low', 'gender-male_emotion-happy_speed-fast_pitch-normal', + 'gender-male_emotion-happy_speed-normal_pitch-high', 'gender-male_emotion-happy_speed-normal_pitch-low', 'gender-male_emotion-happy_speed-normal_pitch-normal', + 'gender-male_emotion-happy_speed-slow_pitch-high', 'gender-male_emotion-happy_speed-slow_pitch-low', 'gender-male_emotion-happy_speed-slow_pitch-normal', + 'gender-male_emotion-neutral_speed-fast_pitch-high', 'gender-male_emotion-neutral_speed-fast_pitch-low', 'gender-male_emotion-neutral_speed-fast_pitch-normal', + 'gender-male_emotion-neutral_speed-normal_pitch-high', 'gender-male_emotion-neutral_speed-normal_pitch-low', 'gender-male_emotion-neutral_speed-normal_pitch-normal', + 'gender-male_emotion-neutral_speed-slow_pitch-high', 'gender-male_emotion-neutral_speed-slow_pitch-low', 'gender-male_emotion-neutral_speed-slow_pitch-normal', + 'gender-male_emotion-sad_speed-fast_pitch-high', 'gender-male_emotion-sad_speed-fast_pitch-low', 'gender-male_emotion-sad_speed-fast_pitch-normal', + 'gender-male_emotion-sad_speed-normal_pitch-high', 'gender-male_emotion-sad_speed-normal_pitch-low', 'gender-male_emotion-sad_speed-normal_pitch-normal', + 'gender-male_emotion-sad_speed-slow_pitch-high', 'gender-male_emotion-sad_speed-slow_pitch-low', 'gender-male_emotion-sad_speed-slow_pitch-normal', + 'gender-male_emotion-surprised_speed-fast_pitch-high', 'gender-male_emotion-surprised_speed-fast_pitch-low', 'gender-male_emotion-surprised_speed-fast_pitch-normal', + 'gender-male_emotion-surprised_speed-normal_pitch-high', 'gender-male_emotion-surprised_speed-normal_pitch-low', 'gender-male_emotion-surprised_speed-normal_pitch-normal', + 'gender-male_emotion-surprised_speed-slow_pitch-high', 'gender-male_emotion-surprised_speed-slow_pitch-low', 'gender-male_emotion-surprised_speed-slow_pitch-normal' +] + +class EMOVASpeechTokenizerConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`EMOVASpeechTokenizer`]. It is used to instantiate + a EMOVASpeechTokenizer model especially designed for training the EMOVA (https://arxiv.org/abs/2409.18042) + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a configuration to the speech tokenizer model presented in EMOVA paper. + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + Args: + s2u_unit_type (`str`, defaults to `40ms_multilingual_8888`): + Unit type to specify model configurations for the speech-to-unit (S2U) encoder. Detailed configs will be found accordingly. + u2s_unit_type (`str`, defaults to `40ms_multilingual_8888_xujing_cosyvoice_FT`): + Unit type to specify model configurations for the unit-to-speech (U2S) decoder. Detailed configs will be found accordingly. + u2s_num_styles, u2s_dim_styles (`int`, defaults to 126 and 256): + Size of the style embedding matrix. + ```python + >>> from transformers import EMOVASpeechTokenizerConfig, EMOVASpeechTokenizer + >>> # Initializing a EMOVA speech tokenizer configuration + >>> configuration = EMOVASpeechTokenizerConfig() + >>> # Initializing a model from the EMOVA speech tokenizer configuration + >>> model = EMOVASpeechTokenizer(configuration) + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "EMOVASpeechTokenizer" + + def __init__( + self, + s2u_unit_type="40ms_multilingual_8888", + u2s_unit_type="40ms_multilingual_8888_xujing_cosyvoice_FT", + u2s_num_styles=126, + u2s_dim_styles=256, + **kwargs, + ): + super().__init__(**kwargs) + + self.s2u_unit_type = s2u_unit_type + self.u2s_unit_type = u2s_unit_type + self.u2s_num_styles = u2s_num_styles + self.u2s_dim_styles = u2s_dim_styles + self.u2s_style2idx = {each:i for i, each in enumerate(U2S_STYLES)} \ No newline at end of file diff --git a/MMaDA/models/configuration_llada.py b/MMaDA/models/configuration_llada.py new file mode 100644 index 0000000000000000000000000000000000000000..3556bdac0bc0b06c6ec606830a28e9b3e6aeed56 --- /dev/null +++ b/MMaDA/models/configuration_llada.py @@ -0,0 +1,463 @@ +""" +LLaDA configuration +""" +from transformers import AutoConfig, PretrainedConfig + +from enum import Enum +from os import PathLike +from typing import Union +from dataclasses import asdict, dataclass, field +from glob import glob +from pathlib import Path +from typing import ( + Any, + Dict, + Iterable, + List, + Optional, + Tuple, + Type, + TypeVar, + Union, + cast, +) + + +__all__ = [ + "ActivationType", + "ActivationCheckpointingStrategy", + "BlockType", + "LayerNormType", + "InitFnType", + "ModelConfig", +] + +PathOrStr = Union[str, PathLike] + + +class StrEnum(str, Enum): + """ + This is equivalent to Python's :class:`enum.StrEnum` since version 3.11. + We include this here for compatibility with older version of Python. + """ + + def __str__(self) -> str: + return self.value + + def __repr__(self) -> str: + return f"'{str(self)}'" + + +class LayerNormType(StrEnum): + default = "default" + """ + The default LayerNorm implementation, equivalent to PyTorch's built-in version. + """ + + low_precision = "low_precision" + """ + A low-precision version of the default LayerNorm. + """ + + rms = "rms" + """ + An RMSNorm implementation. When using ``torch.compile`` this is + probably the fastest implementation. + """ + + gemma_rms = "gemma_rms" + """ + An RMSNorm implementation by gemmma. When using ``torch.compile`` this is + probably the fastest implementation. + """ + + amd_compatible = "amd_compatible" + """ + LayerNorm implemented manually to work around an issue with ROCm. + """ + + +class ActivationType(StrEnum): + gelu = "gelu" + relu = "relu" + silu = "silu" + swiglu = "swiglu" + + +class BlockType(StrEnum): + sequential = "sequential" + parallel = "parallel" + + llama = "llama" + """ + A block similar to the sequential block with slightly different + implementations of operations like attention to imitate the behavior of Llama. + """ + + +class InitFnType(StrEnum): + mitchell = "mitchell" + """ + The strategy suggested to us by Mitchell Wortsman from UW. + This uses a truncated normal distribution with an adaptive standard deviation that depends + on the size of the weights as well as the depth of the layer. + """ + + normal = "normal" + """ + All weights are initialized from the same normal distribution. + """ + + kaiming_normal = "kaiming_normal" + """ + All weights are initialized with the Kaiming method from a normal distribution. + Note this currently won't work with FSDP. + """ + + fan_in = "fan_in" + """ + "Fan-in variance scaling", i.e. normal with a standard deviation of ``1/sqrt(d_in)`` where ``d_in`` + is the input dimensionality of the kernel. + """ + + full_megatron = "full_megatron" + """ + This is what metaseq calls "full megatron init". It is the init used for Llama 2. + """ + + +@dataclass +class ModelConfig(): + """ + LLaDA (model) configuration. + """ + + # Note that the defaults for these attributes are equivalent to the base GPT2 model. + + d_model: int = 768 + """ + The hidden size of the model. + """ + + n_heads: int = 12 + """ + The number of self-attention heads. + """ + + n_kv_heads: Optional[int] = None + """ + The number of heads to use for keys and values. Defaults to `n_heads`. + Set this to ``None`` or ``n_heads`` for normal multi-head attention. + Set this to 1 for multi-query attention. + Set it to some in-between value for Llama2-style grouped query attention. + """ + + n_layers: int = 12 + """ + The number of layers/blocks. + """ + + mlp_ratio: int = 4 + """ + The ratio of the inner MLP dimensionality to ``d_model``. + This is only used when ``mlp_hidden_size`` is not set. + """ + + mlp_hidden_size: Optional[int] = None + """ + Set the exact hidden size for the MLP. Otherwise the inner MLP hidden size will be set to `mlp_ratio * d_model`. + """ + + activation_type: ActivationType = ActivationType.swiglu + """ + The activation function to use within the MLP layers. + """ + + block_type: BlockType = BlockType.sequential + """ + The transformer block implementation. + """ + + block_group_size: int = 1 + """ + The number of blocks to group together into a single parent block. + This has no affect on the number of parameters in the model and is only used to wrap groups + of blocks together with a single FSDP wrapper during training. + """ + + alibi: bool = False + """ + If ``True``, use ALiBi embeddings. Mutually exclusive with ``rope``. + """ + + alibi_bias_max: float = 8.0 + """ + Maximum absolute value of ALiBi bias. + """ + + rope: bool = False + """ + Use rotary positional embeddings (RoPE). Mutually exclusive with ``alibi``. + """ + + rope_full_precision: bool = True + """ + If ``True``, apply RoPE embeddings at full precision regardless of the input type. Otherwise, + apply RoPE at the precision of the input. + """ + + flash_attention: bool = False + """ + If ``True``, use ``FlashAttention``. + """ + + attention_dropout: float = 0.1 + """ + The dropout probability within the attention modules. + """ + + multi_query_attention: Optional[bool] = None + """ + Use the Multi-Query formulation of attention used in PaLM. This reduces the number of parameters + and is more efficient during inference. + """ + + attention_layer_norm: bool = False + """ + Apply layer norm to the keys and queries within the attention mechanism. + This can help stabilize training. + """ + + residual_dropout: float = 0.1 + """ + The dropout probability for the MLP and attention output within each block. + """ + + embedding_dropout: float = 0.1 + """ + The dropout probability for embeddings. + """ + + input_emb_norm: bool = False + """ + An input hidden_states norm implementation by gemmma. + """ + + layer_norm_type: LayerNormType = LayerNormType.default + """ + The layernorm implementation to use. + """ + + layer_norm_with_affine: bool = True + """ + Whether to include bias and weight parameters for the layer norms. + This only affects layer norms that are immediately followed by a linear layer in the forward pass, + so everything except QK-norms. To turn off affines for QK norms as well, set :attr:`attention_layer_norm_with_affine` + to ``False``. + """ + + rms_norm_eps: float = 1e-05 + """ + The rms layernorm eps param. + """ + + attention_layer_norm_with_affine: bool = True + """ + Toggle affine transform for the QK norms. + """ + + max_sequence_length: int = 1024 + """ + The maximum input sequence length supported by the model. + """ + + rope_theta: float = 10000.0 + """ + The rope base param. + """ + + include_qkv_bias: Optional[bool] = False + """ + Whether or not to include bias parameters in qkv linear layers. + """ + + include_bias: bool = False + """ + Whether or not to include bias parameters in linear layers. + In PaLM, they got rid of all bias terms because they found that large + models tend to have near 0 bias terms anyway. + """ + + bias_for_layer_norm: Optional[bool] = None + """ + Whether or not to include bias parameters in layer norm. + This is separate from the include_bias parameter, because of a ROCm crash when biases are disabled in + layer norm. + When this is None (the default), it inherits the setting from include_bias. + """ + + scale_logits: bool = False + """ + If ``True``, scale the output logits by ``1 / sqrt(d_model)``. + """ + + vocab_size: int = 50257 + """ + Vocabulary size of the model. + """ + + embedding_size: Optional[int] = 50304 + """ + The number of embeddings, i.e. the number of tokens. If set to ``None`` it will default + to ``vocab_size``. If ``vocab_size`` is not a multiple of 128, setting this to the + next multiple of 128 that's greater than ``vocab_size`` can improve throughput + substantially. + """ + + weight_tying: bool = True + """ + Whether to tie output linear weights to the input embedding. + """ + + eos_token_id: int = 50256 + """ + The ID of the end-of-sentence special token. + """ + + pad_token_id: int = 50256 + """ + The ID of the token to use for padding. Defaults to the ID of the EOS token. + """ + + mask_token_id: Optional[int] = 50256 + """ + The ID of the token to use for mask token. Defaults to the ID of the EOS token. + """ + + init_device: Optional[str] = None + """ + The torch device to use when initializing the model parameters, e.g. "cpu", "cuda:0", "meta". + """ + + init_fn: InitFnType = InitFnType.normal + """ + The weight initialization strategy. + """ + + init_std: float = 0.02 + """ + The standard deviation to use when initializing weights with a "fixed distribution" ``init_fn``, such + as "normal". + """ + + init_cutoff_factor: Optional[float] = None + """ + A positive factor used to scale the cutoff values when initializing weights with a "fixed distribution" ``init_fn``, such + as "normal". Setting this to None means values are not cutoff. + """ + + precision: Optional[str] = None + """ + Precision used to train/evaluate with. You shouldn't set this directly. + See :data:`TrainConfig.precision` instead. + """ + + @property + def effective_n_kv_heads(self) -> int: + if self.n_kv_heads is None: + if self.multi_query_attention is True: + return 1 + else: + return self.n_heads + else: + if self.multi_query_attention is None: + return self.n_kv_heads + if self.multi_query_attention: + n_kv_heads_should_be = 1 + else: + n_kv_heads_should_be = self.n_heads + if self.n_kv_heads == n_kv_heads_should_be: + return n_kv_heads_should_be + else: + raise Exception( + "You can't set `multi_query_attention` and `n_kv_heads` at the same time." + ) + +class ActivationCheckpointingStrategy(StrEnum): + whole_layer = "whole_layer" + """ + Checkpoint every transformer layer. + """ + + one_in_two = "one_in_two" + """ + Checkpoint one in two transformer layers. + """ + + one_in_three = "one_in_three" + """ + Checkpoint one in three transformer layers. + """ + + one_in_four = "one_in_four" + """ + Checkpoint one in four transformer layers. + """ + + two_in_three = "two_in_three" + """ + Checkpoint two out of every three transformer layers. + """ + + three_in_four = "three_in_four" + """ + Checkpoint three out of four of every transformer layers. + """ + + four_in_five = "four_in_five" + """ + Checkpoint four out of five of every transformer layers. + """ + + nine_in_ten = "nine_in_ten" + """ + Checkpoint nine out of ten of every transformer layers. + """ + + fine_grained = "fine_grained" + """ + Focus checkpointing on where it is cheap to recompute and saves most memory. + """ + + +class LLaDAConfig(PretrainedConfig): + model_type = "llada" + keys_to_ignore_at_inference = ["past_key_values"] # TODO: confirm + + def __init__(self, use_cache: bool = False, **kwargs): + model_config = ModelConfig() + all_kwargs = model_config.__dict__ + all_kwargs.update(kwargs) + all_kwargs.update({"use_cache": use_cache}) + all_kwargs.update( + { + "architectures": all_kwargs.get("architectures", ["LLaDAModelLM"]) + } + ) + super().__init__(**all_kwargs) + + @property + def num_attention_heads(self): + return self.n_heads + + @property + def num_hidden_layers(self): + return self.n_layers + + @property + def hidden_size(self): + return self.d_model + + +# Register the config class so that it is available for transformer pipelines, auto-loading etc. +AutoConfig.register("llada", LLaDAConfig) diff --git a/MMaDA/models/logging.py b/MMaDA/models/logging.py new file mode 100644 index 0000000000000000000000000000000000000000..65814a82380e47e54434c4be97026141772f7298 --- /dev/null +++ b/MMaDA/models/logging.py @@ -0,0 +1,338 @@ +# coding=utf-8 +# Copyright 2023 Optuna, Hugging Face +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Logging utilities.""" + +import logging +import os +import sys +import threading +from logging import CRITICAL # NOQA +from logging import DEBUG # NOQA +from logging import ERROR # NOQA +from logging import FATAL # NOQA +from logging import INFO # NOQA +from logging import NOTSET # NOQA +from logging import WARN # NOQA +from logging import WARNING # NOQA +from typing import Optional + +from tqdm import auto as tqdm_lib + +_lock = threading.Lock() +_default_handler: Optional[logging.Handler] = None + +log_levels = { + "debug": logging.DEBUG, + "info": logging.INFO, + "warning": logging.WARNING, + "error": logging.ERROR, + "critical": logging.CRITICAL, +} + +_default_log_level = logging.WARNING + +_tqdm_active = True + + +def _get_default_logging_level(): + """ + If muse_VERBOSITY env var is set to one of the valid choices return that as the new default level. If it is + not - fall back to `_default_log_level` + """ + env_level_str = os.getenv("muse_VERBOSITY", None) + if env_level_str: + if env_level_str in log_levels: + return log_levels[env_level_str] + else: + logging.getLogger().warning( + f"Unknown option muse_VERBOSITY={env_level_str}, has to be one of: { ', '.join(log_levels.keys()) }" + ) + return _default_log_level + + +def _get_library_name() -> str: + return __name__.split(".")[0] + + +def _get_library_root_logger() -> logging.Logger: + return logging.getLogger(_get_library_name()) + + +def _configure_library_root_logger() -> None: + global _default_handler + + with _lock: + if _default_handler: + # This library has already configured the library root logger. + return + _default_handler = logging.StreamHandler() # Set sys.stderr as stream. + _default_handler.flush = sys.stderr.flush + + # Apply our default configuration to the library root logger. + library_root_logger = _get_library_root_logger() + library_root_logger.addHandler(_default_handler) + library_root_logger.setLevel(_get_default_logging_level()) + library_root_logger.propagate = False + + +def _reset_library_root_logger() -> None: + global _default_handler + + with _lock: + if not _default_handler: + return + + library_root_logger = _get_library_root_logger() + library_root_logger.removeHandler(_default_handler) + library_root_logger.setLevel(logging.NOTSET) + _default_handler = None + + +def get_log_levels_dict(): + return log_levels + + +def get_logger(name: Optional[str] = None) -> logging.Logger: + """ + Return a logger with the specified name. + + This function is not supposed to be directly accessed unless you are writing a custom muse module. + """ + + if name is None: + name = _get_library_name() + + _configure_library_root_logger() + return logging.getLogger(name) + + +def get_verbosity() -> int: + """ + Return the current level for the šŸ¤— muse' root logger as an int. + + Returns: + `int`: The logging level. + + + + šŸ¤— muse has following logging levels: + + - 50: `muse.logging.CRITICAL` or `muse.logging.FATAL` + - 40: `muse.logging.ERROR` + - 30: `muse.logging.WARNING` or `muse.logging.WARN` + - 20: `muse.logging.INFO` + - 10: `muse.logging.DEBUG` + + """ + + _configure_library_root_logger() + return _get_library_root_logger().getEffectiveLevel() + + +def set_verbosity(verbosity: int) -> None: + """ + Set the verbosity level for the šŸ¤— muse' root logger. + + Args: + verbosity (`int`): + Logging level, e.g., one of: + + - `muse.logging.CRITICAL` or `muse.logging.FATAL` + - `muse.logging.ERROR` + - `muse.logging.WARNING` or `muse.logging.WARN` + - `muse.logging.INFO` + - `muse.logging.DEBUG` + """ + + _configure_library_root_logger() + _get_library_root_logger().setLevel(verbosity) + + +def set_verbosity_info(): + """Set the verbosity to the `INFO` level.""" + return set_verbosity(INFO) + + +def set_verbosity_warning(): + """Set the verbosity to the `WARNING` level.""" + return set_verbosity(WARNING) + + +def set_verbosity_debug(): + """Set the verbosity to the `DEBUG` level.""" + return set_verbosity(DEBUG) + + +def set_verbosity_error(): + """Set the verbosity to the `ERROR` level.""" + return set_verbosity(ERROR) + + +def disable_default_handler() -> None: + """Disable the default handler of the HuggingFace muse' root logger.""" + + _configure_library_root_logger() + + assert _default_handler is not None + _get_library_root_logger().removeHandler(_default_handler) + + +def enable_default_handler() -> None: + """Enable the default handler of the HuggingFace muse' root logger.""" + + _configure_library_root_logger() + + assert _default_handler is not None + _get_library_root_logger().addHandler(_default_handler) + + +def add_handler(handler: logging.Handler) -> None: + """adds a handler to the HuggingFace muse' root logger.""" + + _configure_library_root_logger() + + assert handler is not None + _get_library_root_logger().addHandler(handler) + + +def remove_handler(handler: logging.Handler) -> None: + """removes given handler from the HuggingFace muse' root logger.""" + + _configure_library_root_logger() + + assert handler is not None and handler not in _get_library_root_logger().handlers + _get_library_root_logger().removeHandler(handler) + + +def disable_propagation() -> None: + """ + Disable propagation of the library log outputs. Note that log propagation is disabled by default. + """ + + _configure_library_root_logger() + _get_library_root_logger().propagate = False + + +def enable_propagation() -> None: + """ + Enable propagation of the library log outputs. Please disable the HuggingFace muse' default handler to prevent + double logging if the root logger has been configured. + """ + + _configure_library_root_logger() + _get_library_root_logger().propagate = True + + +def enable_explicit_format() -> None: + """ + Enable explicit formatting for every HuggingFace muse' logger. The explicit formatter is as follows: + ``` + [LEVELNAME|FILENAME|LINE NUMBER] TIME >> MESSAGE + ``` + All handlers currently bound to the root logger are affected by this method. + """ + handlers = _get_library_root_logger().handlers + + for handler in handlers: + formatter = logging.Formatter("[%(levelname)s|%(filename)s:%(lineno)s] %(asctime)s >> %(message)s") + handler.setFormatter(formatter) + + +def reset_format() -> None: + """ + Resets the formatting for HuggingFace muse' loggers. + + All handlers currently bound to the root logger are affected by this method. + """ + handlers = _get_library_root_logger().handlers + + for handler in handlers: + handler.setFormatter(None) + + +def warning_advice(self, *args, **kwargs): + """ + This method is identical to `logger.warning()`, but if env var muse_NO_ADVISORY_WARNINGS=1 is set, this + warning will not be printed + """ + no_advisory_warnings = os.getenv("muse_NO_ADVISORY_WARNINGS", False) + if no_advisory_warnings: + return + self.warning(*args, **kwargs) + + +logging.Logger.warning_advice = warning_advice + + +class EmptyTqdm: + """Dummy tqdm which doesn't do anything.""" + + def __init__(self, *args, **kwargs): # pylint: disable=unused-argument + self._iterator = args[0] if args else None + + def __iter__(self): + return iter(self._iterator) + + def __getattr__(self, _): + """Return empty function.""" + + def empty_fn(*args, **kwargs): # pylint: disable=unused-argument + return + + return empty_fn + + def __enter__(self): + return self + + def __exit__(self, type_, value, traceback): + return + + +class _tqdm_cls: + def __call__(self, *args, **kwargs): + if _tqdm_active: + return tqdm_lib.tqdm(*args, **kwargs) + else: + return EmptyTqdm(*args, **kwargs) + + def set_lock(self, *args, **kwargs): + self._lock = None + if _tqdm_active: + return tqdm_lib.tqdm.set_lock(*args, **kwargs) + + def get_lock(self): + if _tqdm_active: + return tqdm_lib.tqdm.get_lock() + + +tqdm = _tqdm_cls() + + +def is_progress_bar_enabled() -> bool: + """Return a boolean indicating whether tqdm progress bars are enabled.""" + global _tqdm_active + return bool(_tqdm_active) + + +def enable_progress_bar(): + """Enable tqdm progress bar.""" + global _tqdm_active + _tqdm_active = True + + +def disable_progress_bar(): + """Disable tqdm progress bar.""" + global _tqdm_active + _tqdm_active = False diff --git a/MMaDA/models/lr_schedulers.py b/MMaDA/models/lr_schedulers.py new file mode 100644 index 0000000000000000000000000000000000000000..082002d54adea55f817ae7511041787327f55a4d --- /dev/null +++ b/MMaDA/models/lr_schedulers.py @@ -0,0 +1,302 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch optimization for diffusion models.""" + +import math +from enum import Enum +from typing import Optional, Union + +from torch.optim import Optimizer +from torch.optim.lr_scheduler import LambdaLR + +from .logging import get_logger + +logger = get_logger(__name__) + + +class SchedulerType(Enum): + LINEAR = "linear" + COSINE = "cosine" + COSINE_WITH_RESTARTS = "cosine_with_restarts" + POLYNOMIAL = "polynomial" + CONSTANT = "constant" + CONSTANT_WITH_WARMUP = "constant_with_warmup" + + +def get_constant_schedule(optimizer: Optimizer, last_epoch: int = -1): + """ + Create a schedule with a constant learning rate, using the learning rate set in optimizer. + + Args: + optimizer ([`~torch.optim.Optimizer`]): + The optimizer for which to schedule the learning rate. + last_epoch (`int`, *optional*, defaults to -1): + The index of the last epoch when resuming training. + + Return: + `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. + """ + return LambdaLR(optimizer, lambda _: 1, last_epoch=last_epoch) + + +def get_constant_schedule_with_warmup(optimizer: Optimizer, num_warmup_steps: int, last_epoch: int = -1): + """ + Create a schedule with a constant learning rate preceded by a warmup period during which the learning rate + increases linearly between 0 and the initial lr set in the optimizer. + + Args: + optimizer ([`~torch.optim.Optimizer`]): + The optimizer for which to schedule the learning rate. + num_warmup_steps (`int`): + The number of steps for the warmup phase. + last_epoch (`int`, *optional*, defaults to -1): + The index of the last epoch when resuming training. + + Return: + `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. + """ + + def lr_lambda(current_step: int): + if current_step < num_warmup_steps: + return float(current_step) / float(max(1.0, num_warmup_steps)) + return 1.0 + + return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch) + + +def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1): + """ + Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after + a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer. + + Args: + optimizer ([`~torch.optim.Optimizer`]): + The optimizer for which to schedule the learning rate. + num_warmup_steps (`int`): + The number of steps for the warmup phase. + num_training_steps (`int`): + The total number of training steps. + last_epoch (`int`, *optional*, defaults to -1): + The index of the last epoch when resuming training. + + Return: + `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. + """ + + def lr_lambda(current_step: int): + if current_step < num_warmup_steps: + return float(current_step) / float(max(1, num_warmup_steps)) + return max( + 0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps)) + ) + + return LambdaLR(optimizer, lr_lambda, last_epoch) + + +def get_cosine_schedule_with_warmup( + optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: float = 0.5, last_epoch: int = -1, min_lr_scale: float = 0.0 +): + """ + Create a schedule with a learning rate that decreases following the values of the cosine function between the + initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the + initial lr set in the optimizer. + + Args: + optimizer ([`~torch.optim.Optimizer`]): + The optimizer for which to schedule the learning rate. + num_warmup_steps (`int`): + The number of steps for the warmup phase. + num_training_steps (`int`): + The total number of training steps. + num_periods (`float`, *optional*, defaults to 0.5): + The number of periods of the cosine function in a schedule (the default is to just decrease from the max + value to 0 following a half-cosine). + last_epoch (`int`, *optional*, defaults to -1): + The index of the last epoch when resuming training. + + Return: + `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. + """ + + # def lr_lambda(current_step): + # if current_step < num_warmup_steps: + # return float(current_step) / float(max(1, num_warmup_steps)) + # progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) + # return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))) + + # return LambdaLR(optimizer, lr_lambda, last_epoch) + + def lr_lambda(current_step): + if current_step < num_warmup_steps: + return float(current_step) / float(max(1, num_warmup_steps)) + progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) + cosine_decay = 0.5 * (1.0 + math.cos(math.pi * 2.0 * num_cycles * progress)) + return min_lr_scale + (1.0 - min_lr_scale) * cosine_decay + + return LambdaLR(optimizer, lr_lambda, last_epoch) + + +def get_cosine_with_hard_restarts_schedule_with_warmup( + optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: int = 1, last_epoch: int = -1 +): + """ + Create a schedule with a learning rate that decreases following the values of the cosine function between the + initial lr set in the optimizer to 0, with several hard restarts, after a warmup period during which it increases + linearly between 0 and the initial lr set in the optimizer. + + Args: + optimizer ([`~torch.optim.Optimizer`]): + The optimizer for which to schedule the learning rate. + num_warmup_steps (`int`): + The number of steps for the warmup phase. + num_training_steps (`int`): + The total number of training steps. + num_cycles (`int`, *optional*, defaults to 1): + The number of hard restarts to use. + last_epoch (`int`, *optional*, defaults to -1): + The index of the last epoch when resuming training. + + Return: + `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. + """ + + def lr_lambda(current_step): + if current_step < num_warmup_steps: + return float(current_step) / float(max(1, num_warmup_steps)) + progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) + if progress >= 1.0: + return 0.0 + return max(0.0, 0.5 * (1.0 + math.cos(math.pi * ((float(num_cycles) * progress) % 1.0)))) + + return LambdaLR(optimizer, lr_lambda, last_epoch) + + +def get_polynomial_decay_schedule_with_warmup( + optimizer, num_warmup_steps, num_training_steps, lr_end=1e-7, power=1.0, last_epoch=-1 +): + """ + Create a schedule with a learning rate that decreases as a polynomial decay from the initial lr set in the + optimizer to end lr defined by *lr_end*, after a warmup period during which it increases linearly from 0 to the + initial lr set in the optimizer. + + Args: + optimizer ([`~torch.optim.Optimizer`]): + The optimizer for which to schedule the learning rate. + num_warmup_steps (`int`): + The number of steps for the warmup phase. + num_training_steps (`int`): + The total number of training steps. + lr_end (`float`, *optional*, defaults to 1e-7): + The end LR. + power (`float`, *optional*, defaults to 1.0): + Power factor. + last_epoch (`int`, *optional*, defaults to -1): + The index of the last epoch when resuming training. + + Note: *power* defaults to 1.0 as in the fairseq implementation, which in turn is based on the original BERT + implementation at + https://github.com/google-research/bert/blob/f39e881b169b9d53bea03d2d341b31707a6c052b/optimization.py#L37 + + Return: + `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. + + """ + + lr_init = optimizer.defaults["lr"] + if not (lr_init > lr_end): + raise ValueError(f"lr_end ({lr_end}) must be be smaller than initial lr ({lr_init})") + + def lr_lambda(current_step: int): + if current_step < num_warmup_steps: + return float(current_step) / float(max(1, num_warmup_steps)) + elif current_step > num_training_steps: + return lr_end / lr_init # as LambdaLR multiplies by lr_init + else: + lr_range = lr_init - lr_end + decay_steps = num_training_steps - num_warmup_steps + pct_remaining = 1 - (current_step - num_warmup_steps) / decay_steps + decay = lr_range * pct_remaining**power + lr_end + return decay / lr_init # as LambdaLR multiplies by lr_init + + return LambdaLR(optimizer, lr_lambda, last_epoch) + + +TYPE_TO_SCHEDULER_FUNCTION = { + SchedulerType.LINEAR: get_linear_schedule_with_warmup, + SchedulerType.COSINE: get_cosine_schedule_with_warmup, + SchedulerType.COSINE_WITH_RESTARTS: get_cosine_with_hard_restarts_schedule_with_warmup, + SchedulerType.POLYNOMIAL: get_polynomial_decay_schedule_with_warmup, + SchedulerType.CONSTANT: get_constant_schedule, + SchedulerType.CONSTANT_WITH_WARMUP: get_constant_schedule_with_warmup, +} + + +def get_scheduler( + name: Union[str, SchedulerType], + optimizer: Optimizer, + num_warmup_steps: Optional[int] = None, + num_training_steps: Optional[int] = None, + num_cycles: int = 1, + power: float = 1.0, + min_lr_scale: float = 0.0 +): + """ + Unified API to get any scheduler from its name. + + Args: + name (`str` or `SchedulerType`): + The name of the scheduler to use. + optimizer (`torch.optim.Optimizer`): + The optimizer that will be used during training. + num_warmup_steps (`int`, *optional*): + The number of warmup steps to do. This is not required by all schedulers (hence the argument being + optional), the function will raise an error if it's unset and the scheduler type requires it. + num_training_steps (`int``, *optional*): + The number of training steps to do. This is not required by all schedulers (hence the argument being + optional), the function will raise an error if it's unset and the scheduler type requires it. + num_cycles (`int`, *optional*): + The number of hard restarts used in `COSINE_WITH_RESTARTS` scheduler. + power (`float`, *optional*, defaults to 1.0): + Power factor. See `POLYNOMIAL` scheduler + last_epoch (`int`, *optional*, defaults to -1): + The index of the last epoch when resuming training. + """ + name = SchedulerType(name) + schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name] + if name == SchedulerType.CONSTANT: + return schedule_func(optimizer) + + # All other schedulers require `num_warmup_steps` + if num_warmup_steps is None: + raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.") + + if name == SchedulerType.CONSTANT_WITH_WARMUP: + return schedule_func(optimizer, num_warmup_steps=num_warmup_steps) + + # All other schedulers require `num_training_steps` + if num_training_steps is None: + raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.") + + if name == SchedulerType.COSINE_WITH_RESTARTS: + return schedule_func( + optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=num_cycles, min_lr_scale=min_lr_scale + ) + + if name == SchedulerType.POLYNOMIAL: + return schedule_func( + optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power + ) + + return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps) diff --git a/MMaDA/models/misc.py b/MMaDA/models/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..eccc28fb05df23e2d3af6038fd2455ae242fb7bf --- /dev/null +++ b/MMaDA/models/misc.py @@ -0,0 +1,53 @@ +from omegaconf import OmegaConf +import torch +from typing import ( + Any, + Callable, + Dict, + Iterable, + List, + NamedTuple, + NewType, + Optional, + Sized, + Tuple, + Type, + TypeVar, + Union, +) +try: + from typing import Literal +except ImportError: + from typing_extensions import Literal + +# Tensor dtype +# for jaxtyping usage, see https://github.com/google/jaxtyping/blob/main/API.md +from jaxtyping import Bool, Complex, Float, Inexact, Int, Integer, Num, Shaped, UInt + +# Config type +from omegaconf import DictConfig + +# PyTorch Tensor type +from torch import Tensor + +# Runtime type checking decorator +from typeguard import typechecked as typechecker + + +def broadcast(tensor, src=0): + if not _distributed_available(): + return tensor + else: + torch.distributed.broadcast(tensor, src=src) + return tensor + +def _distributed_available(): + return torch.distributed.is_available() and torch.distributed.is_initialized() + +def parse_structured(fields: Any, cfg: Optional[Union[dict, DictConfig]] = None) -> Any: + # added by Xavier -- delete '--local-rank' in multi-nodes training, don't know why there is such a keyword + if '--local-rank' in cfg: + del cfg['--local-rank'] + # added by Xavier -- delete '--local-rank' in multi-nodes training, don't know why there is such a keyword + scfg = OmegaConf.structured(fields(**cfg)) + return scfg \ No newline at end of file diff --git a/MMaDA/models/modeling_emova_speech_tokenizer.py b/MMaDA/models/modeling_emova_speech_tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..daf40e45929f78fa484ad07eb07f6a0951be5cbd --- /dev/null +++ b/MMaDA/models/modeling_emova_speech_tokenizer.py @@ -0,0 +1,80 @@ +# coding=utf-8 +# Copyright 2024 The EMOVA team and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" EMOVASpeechTokenizer model """ + +import math +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn +from transformers.modeling_utils import PreTrainedModel + +try: + from emova_speech_tokenizer.speech_utils import get_S2U_ckpt_config_path, load_config, VQCTCFinetuneModel, s2u_extract_unit_demo + from emova_speech_tokenizer.speech_utils import get_U2S_config_checkpoint_file, load_U2S_config, SynthesizerTrn, synthesis +except: + raise ImportError('Dependencies of emova speech tokenizer are not installed properly. Check https://github.com/emova-ollm/EMOVA_speech_tokenizer#installation for detailed instructions.') + +from .configuration_emova_speech_tokenizer import EMOVASpeechTokenizerConfig + +class EMOVASpeechTokenizer(PreTrainedModel): + config_class = EMOVASpeechTokenizerConfig + base_model_prefix = "emova_speech_tokenizer" + + def __init__(self, config: EMOVASpeechTokenizerConfig): + super().__init__(config) + self.config = config + + # s2u encoder configs + _, S2U_config_path = get_S2U_ckpt_config_path(config.s2u_unit_type) + s2u_cfg = load_config(config=S2U_config_path) + s2u_cfg.model.pretrain_chkpt_path = None + + # u2s decoder configs + U2S_config_file, _ = get_U2S_config_checkpoint_file(config.u2s_unit_type) + u2s_cfg = load_U2S_config(U2S_config_file) + + # construct models + self.s2u_config = s2u_cfg.model + self.u2s_config = u2s_cfg + self.encoder = VQCTCFinetuneModel(s2u_cfg.model, trainer=None) + self.decoder = SynthesizerTrn( + u2s_cfg.num_symbols, + u2s_cfg.data.filter_length // 2 + 1, + u2s_cfg.train.segment_size // u2s_cfg.data.hop_length, + n_speakers=u2s_cfg.data.n_speakers, + **u2s_cfg.model + ) + self.style_embedding = nn.Embedding(config.u2s_num_styles, config.u2s_dim_styles) + + @property + def device(self): + return next(self.encoder.parameters()).device + + @property + def dtype(self): + return next(self.encoder.parameters()).dtype + + def encode(self, wav_file): + speech_unit = s2u_extract_unit_demo(self.encoder, wav_file, model_name='SPIRAL-FSQ-CTC', reduced=True) + unit_numbers = speech_unit.replace('<|speech_', '').replace('|>', ' ').strip() + unit_ids = [int(unit) for unit in unit_numbers.split(" ")] + return torch.LongTensor(unit_ids).unsqueeze(0) + + def decode(self, speech_unit, condition=None, output_wav_file='output.wav'): + content_unit = speech_unit.replace('<|speech_', '').replace('|>', ' ').strip() + style_centroid_embedding = self.style_embedding(torch.LongTensor([self.config.u2s_style2idx[condition]]).to(self.device)).unsqueeze(-1) if condition else None + audio = synthesis(content_unit, style_centroid_embedding, self.u2s_config, self.decoder, output_wav_file) + return audio \ No newline at end of file diff --git a/MMaDA/models/modeling_llada.py b/MMaDA/models/modeling_llada.py new file mode 100644 index 0000000000000000000000000000000000000000..6f90494981587eabacd3c698ba7f59508a198b03 --- /dev/null +++ b/MMaDA/models/modeling_llada.py @@ -0,0 +1,1505 @@ +from __future__ import annotations + +import logging +import math +import sys +from abc import abstractmethod +from collections import defaultdict +from functools import partial +from typing import ( + Callable, + Dict, + Iterable, + List, + NamedTuple, + Optional, + Sequence, + Set, + Tuple, +cast, +) +from dataclasses import fields +from typing import List, Optional, Tuple, Union + +import torch +import torch.backends.cuda +import torch.nn as nn +import torch.nn.functional as F +from torch import einsum +from transformers import PreTrainedModel +from transformers.modeling_outputs import CausalLMOutputWithPast +from transformers.models.auto import AutoModel +from transformers.cache_utils import Cache + +from .configuration_llada import ( + LLaDAConfig, + StrEnum, + InitFnType, + ActivationType, + BlockType, + LayerNormType, + ModelConfig, + ActivationCheckpointingStrategy, +) + +if sys.version_info.minor > 8: + from collections.abc import MutableMapping +elif sys.version_info.minor == 8: + from typing import MutableMapping +else: + raise SystemExit("This script supports Python 3.8 or higher") + +__all__ = [ + "LayerNormBase", + "LayerNorm", + "RMSLayerNorm", + "GemmaRMSLayerNorm", + "RotaryEmbedding", + "Activation", + "GELU", + "ReLU", + "SwiGLU", + "LLaDABlock", + "LLaDASequentialBlock", + "LLaDAModel", + "LLaDAOutput", + "LLaDAGenerateOutput", +] + + +log = logging.getLogger(__name__) + + +class ModuleType(StrEnum): + in_module = "in" + out_module = "out" + emb = "emb" + final_out = "final_out" + + +def init_weights( + config: ModelConfig, + module: Union[nn.Linear, nn.Embedding], + d: Optional[int] = None, + layer_id: Optional[int] = None, + std_factor: float = 1.0, + type_of_module: Optional[ModuleType] = None, +) -> None: + """ + Initialize weights of a linear or embedding module. + + :param config: The model config. + :param module: The linear or embedding submodule to initialize. + :param d: The effective input dimensionality of the weights. This could be smaller than the actual dimensions + for fused layers. + :param layer_id: When set, the standard deviation for the "mitchell" method will be adjusted by + ``1 / sqrt(2 * (layer_id + 1))``. + """ + d = d if d is not None else config.d_model + if config.init_fn == InitFnType.normal: + std = config.init_std * std_factor + if config.init_cutoff_factor is not None: + cutoff_value = config.init_cutoff_factor * std + nn.init.trunc_normal_(module.weight, mean=0.0, std=std, a=-cutoff_value, b=cutoff_value) + else: + nn.init.normal_(module.weight, mean=0.0, std=std) + elif config.init_fn == InitFnType.mitchell: + std = std_factor / math.sqrt(d) + if layer_id is not None: + std = std / math.sqrt(2 * (layer_id + 1)) + nn.init.trunc_normal_(module.weight, mean=0.0, std=std, a=-3 * std, b=3 * std) + elif config.init_fn == InitFnType.kaiming_normal: + nn.init.kaiming_normal_(module.weight, nonlinearity="relu") + elif config.init_fn == InitFnType.fan_in: + std = std_factor / math.sqrt(d) + nn.init.normal_(module.weight, mean=0.0, std=std) + elif config.init_fn == InitFnType.full_megatron: + if type_of_module is None: + raise RuntimeError(f"When using the {InitFnType.full_megatron} init, every module must have a type.") + + cutoff_factor = config.init_cutoff_factor + if cutoff_factor is None: + cutoff_factor = 3 + + if type_of_module == ModuleType.in_module: + # for att_proj (same as QKV), ff_proj + std = config.init_std + elif type_of_module == ModuleType.out_module: + # for attn_out, ff_out + std = config.init_std / math.sqrt(2.0 * config.n_layers) + elif type_of_module == ModuleType.emb: + # positional embeddings (wpe) + # token embeddings (wte) + std = config.init_std + elif type_of_module == ModuleType.final_out: + # final output (ff_out) + std = config.d_model**-0.5 + else: + raise RuntimeError(f"Unknown module type '{type_of_module}'") + nn.init.trunc_normal_( + module.weight, + mean=0.0, + std=std, + a=-cutoff_factor * std, + b=cutoff_factor * std, + ) + else: + raise NotImplementedError(config.init_fn) + + if isinstance(module, nn.Linear): + if module.bias is not None: + nn.init.zeros_(module.bias) + + if config.init_fn == InitFnType.normal and getattr(module, "_is_residual", False): + with torch.no_grad(): + module.weight.div_(math.sqrt(2 * config.n_layers)) + + +def ensure_finite_(x: torch.Tensor, check_neg_inf: bool = True, check_pos_inf: bool = False): + """ + Modify ``x`` in place to replace ``float("-inf")`` with the minimum value of the dtype when ``check_neg_inf`` + is ``True`` and to replace ``float("inf")`` with the maximum value of the dtype when ``check_pos_inf`` is ``True``. + """ + if check_neg_inf: + x.masked_fill_(x == float("-inf"), torch.finfo(x.dtype).min) + if check_pos_inf: + x.masked_fill_(x == float("inf"), torch.finfo(x.dtype).max) + + +def activation_checkpoint_function(cfg: ModelConfig): + preserve_rng_state = ( + (cfg.attention_dropout == 0.0) and (cfg.embedding_dropout == 0.0) and (cfg.residual_dropout == 0.0) + ) + from torch.utils.checkpoint import checkpoint + + return partial( + checkpoint, + preserve_rng_state=preserve_rng_state, + use_reentrant=False, + ) + + +class BufferCache(dict, MutableMapping[str, torch.Tensor]): + """ + Cache for attention biases and other things that would normally be stored as buffers. + We avoid using buffers because we've run into various issues doing so with FSDP. + In general it appears the way FSDP handles buffers is not well-defined. + It doesn't shard them but apparently it does synchronize them across processes, which we want to avoid + since (A) it isn't necessary, and (B) we sometimes have `-inf` in these biases which might get turned into + NaNs when they're synchronized due to casting or some other issue. + """ + + +def _non_meta_init_device(config: ModelConfig) -> torch.device: + if config.init_device is not None and config.init_device != "meta": + return torch.device(config.init_device) + else: + return torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +class Dropout(nn.Dropout): + def forward(self, input: torch.Tensor) -> torch.Tensor: + if self.p == 0.0: + return input + else: + return F.dropout(input, self.p, self.training, self.inplace) + + +class LayerNormBase(nn.Module): + def __init__( + self, + config: ModelConfig, + *, + size: Optional[int] = None, + elementwise_affine: Optional[bool] = True, + eps: float = 1e-05, + ): + super().__init__() + self.config = config + self.eps = eps + self.normalized_shape = (size or config.d_model,) + if elementwise_affine or (elementwise_affine is None and self.config.layer_norm_with_affine): + self.weight = nn.Parameter(torch.ones(self.normalized_shape, device=config.init_device)) + use_bias = self.config.bias_for_layer_norm + if use_bias is None: + use_bias = self.config.include_bias + if use_bias: + self.bias = nn.Parameter(torch.zeros(self.normalized_shape, device=config.init_device)) + else: + self.register_parameter("bias", None) + else: + self.register_parameter("bias", None) + self.register_parameter("weight", None) + + @abstractmethod + def forward(self, x: torch.Tensor) -> torch.Tensor: + raise NotImplementedError + + @classmethod + def build(cls, config: ModelConfig, size: Optional[int] = None, **kwargs) -> LayerNormBase: + if config.layer_norm_type == LayerNormType.default: + return LayerNorm(config, size=size, low_precision=False, **kwargs) + elif config.layer_norm_type == LayerNormType.low_precision: + return LayerNorm(config, size=size, low_precision=True, **kwargs) + elif config.layer_norm_type == LayerNormType.rms: + return RMSLayerNorm(config, size=size, **kwargs) + elif config.layer_norm_type == LayerNormType.gemma_rms: + return GemmaRMSLayerNorm(config, size=size, **kwargs) + else: + raise NotImplementedError(f"Unknown LayerNorm type: '{config.layer_norm_type}'") + + def _cast_if_autocast_enabled(self, tensor: torch.Tensor, dtype: Optional[torch.dtype] = None) -> torch.Tensor: + # NOTE: `is_autocast_enabled()` only checks for CUDA autocast, so we use the separate function + # `is_autocast_cpu_enabled()` for CPU autocast. + # See https://github.com/pytorch/pytorch/issues/110966. + if tensor.device.type == "cuda" and torch.is_autocast_enabled(): + return tensor.to(dtype=dtype if dtype is not None else torch.get_autocast_gpu_dtype()) + elif tensor.device.type == "cpu" and torch.is_autocast_cpu_enabled(): + return tensor.to(dtype=dtype if dtype is not None else torch.get_autocast_cpu_dtype()) + else: + return tensor + + def reset_parameters(self): + if self.weight is not None: + torch.nn.init.ones_(self.weight) # type: ignore + if self.bias is not None: + torch.nn.init.zeros_(self.bias) # type: ignore + + +class LayerNorm(LayerNormBase): + """ + The default :class:`LayerNorm` implementation which can optionally run in low precision. + """ + + def __init__( + self, + config: ModelConfig, + size: Optional[int] = None, + low_precision: bool = False, + elementwise_affine: Optional[bool] = None, + eps: float = 1e-05, + ): + super().__init__(config, size=size, elementwise_affine=elementwise_affine, eps=eps) + self.low_precision = low_precision + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.low_precision: + module_device = x.device + downcast_x = self._cast_if_autocast_enabled(x) + downcast_weight = ( + self._cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight + ) + downcast_bias = self._cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias + with torch.autocast(enabled=False, device_type=module_device.type): + return F.layer_norm( + downcast_x, self.normalized_shape, weight=downcast_weight, bias=downcast_bias, eps=self.eps + ) + else: + return F.layer_norm(x, self.normalized_shape, weight=self.weight, bias=self.bias, eps=self.eps) + + +class RMSLayerNorm(LayerNormBase): + """ + RMS layer norm, a simplified :class:`LayerNorm` implementation + """ + + def __init__( + self, + config: ModelConfig, + size: Optional[int] = None, + elementwise_affine: Optional[bool] = None, + eps: float = 1e-5, + ): + super().__init__(config, size=size, elementwise_affine=elementwise_affine, eps=config.rms_norm_eps) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + with torch.autocast(enabled=False, device_type=x.device.type): + og_dtype = x.dtype + x = x.to(torch.float32) + variance = x.pow(2).mean(-1, keepdim=True) + x = x * torch.rsqrt(variance + self.eps) + x = x.to(og_dtype) + + if self.weight is not None: + if self.bias is not None: + return self.weight * x + self.bias + else: + return self.weight * x + else: + return x + + +class GemmaRMSLayerNorm(LayerNormBase): + """ + Gemma RMS layer norm, a simplified :class:`LayerNorm` implementation + """ + + def __init__( + self, + config: ModelConfig, + size: Optional[int] = None, + elementwise_affine: Optional[bool] = None, + eps: float = 1e-5, + ): + super().__init__(config, size=size, elementwise_affine=elementwise_affine, eps=config.rms_norm_eps) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + with torch.autocast(enabled=False, device_type=x.device.type): + og_dtype = x.dtype + x = x.to(torch.float32) + variance = x.pow(2).mean(-1, keepdim=True) + x = x * torch.rsqrt(variance + self.eps) + x = x.to(og_dtype) + + if self.weight is not None: + if self.bias is not None: + return x * (1 + self.weight) + self.bias + else: + return x * (1 + self.weight) + else: + return x + + +class RotaryEmbedding(nn.Module): + """ + [Rotary positional embeddings (RoPE)](https://arxiv.org/abs/2104.09864). + """ + + def __init__(self, config: ModelConfig, cache: BufferCache): + super().__init__() + self.config = config + self.__cache = cache + # Warm up cache. + self.rope_theta = config.rope_theta + self.get_rotary_embedding(config.max_sequence_length, _non_meta_init_device(config)) + + def get_rotary_embedding(self, seq_len: int, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]: + if ( + (pos_sin := self.__cache.get("rope_pos_sin")) is not None + and (pos_cos := self.__cache.get("rope_pos_cos")) is not None + and pos_sin.shape[-2] >= seq_len + and pos_cos.shape[-2] >= seq_len + ): + if pos_sin.device != device: + pos_sin = pos_sin.to(device) + self.__cache["rope_pos_sin"] = pos_sin + if pos_cos.device != device: + pos_cos = pos_cos.to(device) + self.__cache["rope_pos_cos"] = pos_cos + return pos_sin[:, :, :seq_len, :], pos_cos[:, :, :seq_len, :] + + with torch.autocast(device.type, enabled=False): + dim = self.config.d_model // self.config.n_heads + inv_freq = 1.0 / (self.rope_theta ** (torch.arange(0, dim, 2, device=device, dtype=torch.float) / dim)) + seq = torch.arange(seq_len, device=device, dtype=torch.float) + freqs = einsum("i , j -> i j", seq, inv_freq) + positions = torch.cat((freqs, freqs), dim=-1) + pos_sin, pos_cos = positions.sin()[None, None, :, :], positions.cos()[None, None, :, :] + self.__cache["rope_pos_sin"] = pos_sin + self.__cache["rope_pos_cos"] = pos_cos + return pos_sin, pos_cos + + def rotate_half(self, x: torch.Tensor) -> torch.Tensor: + B, nh, T, hs = x.size() + x = x.view(B, nh, T, 2, hs // 2) + x1, x2 = x.unbind(dim=-2) + return torch.cat((-x2, x1), dim=-1) + + def apply_rotary_pos_emb(self, pos_sin: torch.Tensor, pos_cos: torch.Tensor, t: torch.Tensor) -> torch.Tensor: + return ((t * pos_cos) + (self.rotate_half(t) * pos_sin)).to(t.dtype) + + def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + if self.config.rope_full_precision: + q_, k_ = q.float(), k.float() + else: + q_, k_ = q, k + + with torch.autocast(q.device.type, enabled=False): + query_len, key_len = q_.shape[-2], k_.shape[-2] # could be different if layer_past not None + pos_sin, pos_cos = self.get_rotary_embedding(key_len, q_.device) + pos_sin = pos_sin.type_as(q_) + pos_cos = pos_cos.type_as(q_) + q_ = self.apply_rotary_pos_emb( + pos_sin[:, :, key_len - query_len : key_len, :], + pos_cos[:, :, key_len - query_len : key_len, :], + q_, + ) + k_ = self.apply_rotary_pos_emb(pos_sin, pos_cos, k_) + return q_.type_as(q), k_.type_as(k) + + +class Activation(nn.Module): + def __init__(self, config: ModelConfig): + super().__init__() + self.config = config + + @abstractmethod + def forward(self, x: torch.Tensor) -> torch.Tensor: + raise NotImplementedError + + @property + @abstractmethod + def output_multiplier(self) -> float: + raise NotImplementedError + + @classmethod + def build(cls, config: ModelConfig) -> Activation: + if config.activation_type == ActivationType.gelu: + return cast(Activation, GELU(approximate="none")) + elif config.activation_type == ActivationType.relu: + return cast(Activation, ReLU(inplace=False)) + elif config.activation_type == ActivationType.silu: + return cast(Activation, SiLU(inplace=False)) + elif config.activation_type == ActivationType.swiglu: + return SwiGLU(config) + else: + raise NotImplementedError(f"Unknown activation: '{config.activation_type}'") + + +class GELU(nn.GELU): + @property + def output_multiplier(self) -> float: + return 1.0 + + +class ReLU(nn.ReLU): + @property + def output_multiplier(self) -> float: + return 1.0 + +class SiLU(nn.SiLU): + @property + def output_multiplier(self) -> float: + return 1.0 + +class SwiGLU(Activation): + def forward(self, x: torch.Tensor) -> torch.Tensor: + x, gate = x.chunk(2, dim=-1) + return F.silu(gate) * x + + @property + def output_multiplier(self) -> float: + return 0.5 + + +def causal_attention_bias(seq_len: int, device: torch.device) -> torch.FloatTensor: + att_bias = torch.triu( + torch.ones(seq_len, seq_len, device=device, dtype=torch.float), + diagonal=1, + ) + att_bias.masked_fill_(att_bias == 1, torch.finfo(att_bias.dtype).min) + return att_bias.view(1, 1, seq_len, seq_len) # type: ignore + + +def get_causal_attention_bias(cache: BufferCache, seq_len: int, device: torch.device) -> torch.Tensor: + if (causal_bias := cache.get("causal_attention_bias")) is not None and causal_bias.shape[-1] >= seq_len: + if causal_bias.device != device: + causal_bias = causal_bias.to(device) + cache["causal_attention_bias"] = causal_bias + return causal_bias + with torch.autocast(device.type, enabled=False): + causal_bias = causal_attention_bias(seq_len, device) + cache["causal_attention_bias"] = causal_bias + return causal_bias + + +def alibi_attention_bias(seq_len: int, config: ModelConfig, device: torch.device) -> torch.FloatTensor: + alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.float, device=device).view(1, 1, 1, seq_len) + + # shape: (1, 1, seq_len, seq_len) + alibi_bias = alibi_bias - torch.arange(1 - seq_len, 1, dtype=torch.float, device=device).view(1, 1, seq_len, 1) + alibi_bias.abs_().mul_(-1) + + # shape: (n_heads,) + m = torch.arange(1, config.n_heads + 1, dtype=torch.float, device=device) + m.mul_(config.alibi_bias_max / config.n_heads) + + # shape: (1, n_heads, seq_len, seq_len) + return alibi_bias * (1.0 / (2 ** m.view(1, config.n_heads, 1, 1))) # type: ignore + + +class LLaDABlock(nn.Module): + """ + A base class for transformer block implementations. + """ + + def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache): + super().__init__() + self.layer_id = layer_id + self.config = config + self.hidden_size = ( + config.mlp_hidden_size if config.mlp_hidden_size is not None else config.mlp_ratio * config.d_model + ) + self.__cache = cache + assert config.d_model % config.n_heads == 0 + + self._activation_checkpoint_fn = None + + # Dropout. + self.dropout = Dropout(config.residual_dropout) + + # Layer norms. + self.k_norm: Optional[LayerNormBase] = None + self.q_norm: Optional[LayerNormBase] = None + if config.attention_layer_norm: + self.k_norm = LayerNormBase.build( + config, + size=(config.d_model // config.n_heads) * config.effective_n_kv_heads, + elementwise_affine=config.attention_layer_norm_with_affine, + ) + self.q_norm = LayerNormBase.build(config, elementwise_affine=config.attention_layer_norm_with_affine) + + # Activation function. + self.act = Activation.build(config) + assert (self.act.output_multiplier * self.hidden_size) % 1 == 0 + + # Attention output projection. + self.attn_out = nn.Linear( + config.d_model, config.d_model, bias=config.include_bias, device=config.init_device + ) + + # Feed-forward output projection. + self.ff_out = nn.Linear( + int(self.act.output_multiplier * self.hidden_size), + config.d_model, + bias=config.include_bias, + device=config.init_device, + ) + self.ff_out._is_residual = True # type: ignore + + # Rotary embeddings. + if self.config.rope: + self.rotary_emb = RotaryEmbedding(config, self.__cache) + + self.flash_attn_func = None + if config.flash_attention: + try: + from flash_attn import flash_attn_func # type: ignore + + self.flash_attn_func = flash_attn_func + except ModuleNotFoundError: + pass + + def reset_parameters(self): + if self.k_norm is not None: + self.k_norm.reset_parameters() + if self.q_norm is not None: + self.q_norm.reset_parameters() + init_weights( + self.config, + self.attn_out, + d=self.config.d_model, + layer_id=self.layer_id, + type_of_module=ModuleType.out_module, + ) + init_weights( + self.config, + self.ff_out, + d=self.ff_out.in_features, + layer_id=self.layer_id, + type_of_module=ModuleType.out_module, + ) + + def set_activation_checkpointing(self, strategy: Optional[ActivationCheckpointingStrategy]): + if strategy == ActivationCheckpointingStrategy.fine_grained: + self._activation_checkpoint_fn = activation_checkpoint_function(self.config) + else: + self._activation_checkpoint_fn = None + + @classmethod + def _cast_attn_bias(cls, bias: torch.Tensor, input_dtype: torch.dtype) -> torch.Tensor: + target_dtype = input_dtype + # NOTE: `is_autocast_enabled()` only checks for CUDA autocast, so we use the separate function + # `is_autocast_cpu_enabled()` for CPU autocast. + # See https://github.com/pytorch/pytorch/issues/110966. + if bias.device.type == "cuda" and torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + elif bias.device.type == "cpu" and torch.is_autocast_cpu_enabled(): + target_dtype = torch.get_autocast_cpu_dtype() + if bias.dtype != target_dtype: + bias = bias.to(target_dtype) + ensure_finite_(bias, check_neg_inf=True, check_pos_inf=False) + return bias + + def _scaled_dot_product_attention( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + ) -> torch.Tensor: + """ + Computes scaled dot product attention on query, key and value tensors, using an optional + attention mask if passed, and applying dropout if a probability greater than 0.0 is specified. + """ + if self.flash_attn_func is not None and attn_mask is None: + r = self.flash_attn_func( + q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), dropout_p=dropout_p, causal=False + ) + return r.transpose(1, 2) + else: + # torch's sdpa doesn't support GQA, so we're doing this + assert k.size(1) == v.size(1) + num_kv_heads = k.size(1) + num_q_heads = q.size(1) + if num_q_heads != num_kv_heads: + assert num_q_heads % num_kv_heads == 0 + k = k.repeat_interleave(num_q_heads // num_kv_heads, dim=1, output_size=num_q_heads) + v = v.repeat_interleave(num_q_heads // num_kv_heads, dim=1, output_size=num_q_heads) + + # Modify: MDM set causal to False, and with no attn_mask. + return F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=attn_mask, + dropout_p=dropout_p, + is_causal=False, + ) + + def attention( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + attention_bias: Optional[torch.Tensor] = None, + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]: + B, T, C = q.size() # batch size, sequence length, d_model + dtype = k.dtype + + # Optionally apply layer norm to keys and queries. + if self.q_norm is not None and self.k_norm is not None: + q = self.q_norm(q).to(dtype=dtype) + k = self.k_norm(k).to(dtype=dtype) + + # Move head forward to be next to the batch dim. + # shape: (B, nh, T, hs) + q = q.view(B, T, self.config.n_heads, C // self.config.n_heads).transpose(1, 2) + # shape: (B, n_kv_h, T, hs) + k = k.view(B, T, self.config.effective_n_kv_heads, C // self.config.n_heads).transpose(1, 2) + # shape: (B, n_kv_h, T, hs) + v = v.view(B, T, self.config.effective_n_kv_heads, C // self.config.n_heads).transpose(1, 2) + + if layer_past is not None: + past_key, past_value = layer_past + k = torch.cat((past_key, k), dim=-2) + v = torch.cat((past_value, v), dim=-2) + + present = (k, v) if use_cache else None + query_len, key_len = q.shape[-2], k.shape[-2] # could be different if layer_past not None + + if self.config.rope: + # Apply rotary embeddings. + q, k = self.rotary_emb(q, k) + + if attention_bias is not None: + # Resize and cast attention bias. + # The current dtype of the attention bias might not match the dtype that the SDP attn function will + # run in if AMP is enabled, and this can be a problem if some tokens are masked out due to padding + # as down-casting the attention bias to the autocast precision will result in -infs, which will + # cause the SDP attn function to produce NaNs. + attention_bias = self._cast_attn_bias( + attention_bias[:, :, key_len - query_len : key_len, :key_len], dtype + ) + + # Get the attention scores. + # shape: (B, nh, T, hs) + att = self._scaled_dot_product_attention( + q, + k, + v, + attn_mask=attention_bias, + dropout_p=0.0 if not self.training else self.config.attention_dropout, + is_causal=False, + ) + + # Re-assemble all head outputs side-by-side. + att = att.transpose(1, 2).contiguous().view(B, T, C) + + # Apply output projection. + return self.attn_out(att), present + + @abstractmethod + def forward( + self, + x: torch.Tensor, + attention_bias: Optional[torch.FloatTensor] = None, + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]: + raise NotImplementedError + + @classmethod + def build(cls, layer_id: int, config: ModelConfig, cache: BufferCache) -> LLaDABlock: + if config.block_type == BlockType.sequential: + return LLaDASequentialBlock(layer_id, config, cache) + elif config.block_type == BlockType.llama: + return LLaDALlamaBlock(layer_id, config, cache) + else: + raise NotImplementedError(f"Unknown block type: '{config.block_type}'") + + +class LLaDASequentialBlock(LLaDABlock): + """ + This is a typical transformer block where the output is computed as ``MLP(LN(x + Attention(LN(x))))`` + (plus another skip connection). + """ + + def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache): + super().__init__(layer_id, config, cache) + # Layer norms. + self.attn_norm = LayerNorm.build(config) + self.ff_norm = LayerNorm.build(config) + # Attention input projection. Projects x -> (q, k, v) + head_dim = config.d_model // config.n_heads + self.fused_dims = ( + config.d_model, + config.effective_n_kv_heads * head_dim, + config.effective_n_kv_heads * head_dim, + ) + self.att_proj = nn.Linear( + config.d_model, sum(self.fused_dims), bias=config.include_bias | config.include_qkv_bias, device=config.init_device + ) + # Feed-forward input projection. + self.ff_proj = nn.Linear( + config.d_model, self.hidden_size, bias=config.include_bias, device=config.init_device + ) + + def reset_parameters(self): + super().reset_parameters() + self.attn_norm.reset_parameters() + self.ff_norm.reset_parameters() + # NOTE: the standard deviation for these weights does not depend on the layer. + init_weights( + self.config, self.att_proj, d=self.config.d_model, layer_id=None, type_of_module=ModuleType.in_module + ) + init_weights( + self.config, self.ff_proj, d=self.config.d_model, layer_id=None, type_of_module=ModuleType.in_module + ) + + def forward( + self, + x: torch.Tensor, + attention_bias: Optional[torch.Tensor] = None, + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]: + # Get query, key, value projections. + # shape: + # - for regular attn q, k, v: (batch_size, seq_len, d_model) + # - for multi-query attn q: (batch_size, seq_len, d_model) + # k, v: (batch_size, seq_len, d_model // n_heads) + # - for group query attn q: (batch_size, seq_len, d_model) + # k, v: (batch_size, seq_len, d_model // n_kv_heads) + if self._activation_checkpoint_fn is not None: + q, k, v = self.att_proj(self._activation_checkpoint_fn(self.attn_norm, x)).split( + self.fused_dims, dim=-1 + ) + else: + q, k, v = self.att_proj(self.attn_norm(x)).split(self.fused_dims, dim=-1) + + # Get attention scores. + if self._activation_checkpoint_fn is not None: + att, cache = self._activation_checkpoint_fn( # type: ignore + self.attention, q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache + ) + else: + att, cache = self.attention(q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache) + + # Add attention scores. + # shape: (B, T, C) + x = x + self.dropout(att) + + # Add feed-forward projection. + # shape: (batch_size, seq_len, d_model) + og_x = x + if self._activation_checkpoint_fn is not None: + x = self._activation_checkpoint_fn(self.ff_norm, x) # type: ignore + else: + x = self.ff_norm(x) + x = self.ff_proj(x) + if self._activation_checkpoint_fn is not None: + x = self._activation_checkpoint_fn(self.act, x) # type: ignore + else: + x = self.act(x) + x = self.ff_out(x) + x = self.dropout(x) + x = og_x + x + + return x, cache + + +class LLaDALlamaBlock(LLaDABlock): + """ + This is a transformer block where the output is computed as ``MLP(LN(x + Attention(LN(x))))`` + (plus another skip connection). This block is similar to `LLaDASequentialBlock` + but some operations have slightly different implementations to imitate the + behavior of Llama. + """ + + def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache): + super().__init__(layer_id, config, cache) + # Layer norms. + self.attn_norm = LayerNorm.build(config) + self.ff_norm = LayerNorm.build(config) + self.__cache = cache + + # Attention input projection. Projects x -> (q, k, v) + head_dim = config.d_model // config.n_heads + q_proj_out_dim = config.d_model + k_proj_out_dim = config.effective_n_kv_heads * head_dim + v_proj_out_dim = config.effective_n_kv_heads * head_dim + self.q_proj = nn.Linear( + config.d_model, q_proj_out_dim, bias=config.include_bias | config.include_qkv_bias, device=config.init_device + ) + self.k_proj = nn.Linear( + config.d_model, k_proj_out_dim, bias=config.include_bias | config.include_qkv_bias, device=config.init_device + ) + self.v_proj = nn.Linear( + config.d_model, v_proj_out_dim, bias=config.include_bias | config.include_qkv_bias, device=config.init_device + ) + + # Feed-forward input projection. + self.ff_proj = nn.Linear( + config.d_model, self.hidden_size, bias=config.include_bias, device=config.init_device + ) + # new add + self.up_proj = nn.Linear( + config.d_model, self.hidden_size, bias=config.include_bias, device=config.init_device + ) + + def reset_parameters(self): + super().reset_parameters() + self.attn_norm.reset_parameters() + self.ff_norm.reset_parameters() + # NOTE: the standard deviation for these weights does not depend on the layer. + init_weights(self.config, self.q_proj, d=self.config.d_model, layer_id=None) + init_weights(self.config, self.k_proj, d=self.config.d_model, layer_id=None) + init_weights(self.config, self.v_proj, d=self.config.d_model, layer_id=None) + init_weights(self.config, self.ff_proj, d=self.config.d_model, layer_id=None) + init_weights(self.config, self.up_proj, d=self.config.d_model, layer_id=None) # new add + + def forward( + self, + x: torch.Tensor, + attention_bias: Optional[torch.Tensor] = None, + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]: + # Get query, key, value projections. + # shape: + # - for regular attn q, k, v: (batch_size, seq_len, d_model) + # - for multi-query attn q: (batch_size, seq_len, d_model) + # k, v: (batch_size, seq_len, d_model // n_heads) + # - for group query attn q: (batch_size, seq_len, d_model) + # k, v: (batch_size, seq_len, d_model // n_kv_heads) + x_normed = self.attn_norm(x) + q = self.q_proj(x_normed) + k = self.k_proj(x_normed) + v = self.v_proj(x_normed) + + # Get attention scores. + if self._activation_checkpoint_fn is not None: + att, cache = self._activation_checkpoint_fn( # type: ignore + self.attention, q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache + ) + else: + att, cache = self.attention(q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache) + + # Add attention scores. + # shape: (B, T, C) + x = x + self.dropout(att) + + # Add feed-forward projection. + # shape: (batch_size, seq_len, d_model) + og_x = x + if self._activation_checkpoint_fn is not None: + x = self._activation_checkpoint_fn(self.ff_norm, x) # type: ignore + else: + x = self.ff_norm(x) + x, x_up = self.ff_proj(x), self.up_proj(x) # new add + if self._activation_checkpoint_fn is not None: + x = self._activation_checkpoint_fn(self.act, x) # type: ignore + else: + x = self.act(x) + x = x * x_up # new add + x = self.ff_out(x) + x = self.dropout(x) + x = og_x + x + + return x, cache + + +class LLaDAOutput(NamedTuple): + logits: torch.FloatTensor + """ + A tensor of shape `(batch_size, seq_len, vocab_size)` representing the log probabilities + for the next token *before* normalization via (log) softmax. + """ + + attn_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] + """ + Attention keys and values from each block. + """ + + hidden_states: Optional[Tuple[torch.Tensor]] + """ + Hidden states from each block. + """ + + +class LLaDAGenerateOutput(NamedTuple): + token_ids: torch.LongTensor + """ + The generated token IDs, a tensor of shape `(batch_size, beam_size, max_steps)`. + These do *not* include the original input IDs. + """ + + scores: torch.FloatTensor + """ + The scores of the generated sequences, a tensor of shape `(batch_size, beam_size)`. + """ + + +class LLaDABlockGroup(nn.ModuleList): + def __init__(self, config: ModelConfig, layer_offset: int, modules: Optional[Iterable[nn.Module]] = None): + super().__init__(modules) + self.config = config + self.layer_offset = layer_offset + self.activation_checkpointing_strategy: Optional[ActivationCheckpointingStrategy] = None + self._activation_checkpoint_fn = activation_checkpoint_function(self.config) + + def forward( + self, + x: torch.Tensor, + attention_bias: Optional[torch.FloatTensor] = None, + layers_past: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[List[Tuple[torch.Tensor, torch.Tensor]]]]: + attn_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = [] if use_cache else None + for block_idx, block in enumerate(self): + layer_past = None if layers_past is None else layers_past[block_idx] + block_idx += self.layer_offset + if ( + (self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.whole_layer) + or ( + self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.one_in_two + and block_idx % 2 == 0 + ) + or ( + self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.one_in_three + and block_idx % 3 == 0 + ) + or ( + self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.one_in_four + and block_idx % 4 == 0 + ) + ): + # shape: (batch_size, seq_len, d_model) + x, cache = self._activation_checkpoint_fn( # type: ignore + block, x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache + ) + else: + # shape: (batch_size, seq_len, d_model) + x, cache = block(x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache) + if attn_key_values is not None: + assert cache is not None + attn_key_values.append(cache) + return x, attn_key_values + + def reset_parameters(self): + for block in self: + block.reset_parameters() + + def set_activation_checkpointing(self, strategy: Optional[ActivationCheckpointingStrategy]): + self.activation_checkpointing_strategy = strategy + for block in self: + block.set_activation_checkpointing(strategy) + + +class LLaDAModel(nn.Module): + def __init__(self, config: ModelConfig, init_params: bool = True): + super().__init__() + self.config = config + self.__cache = BufferCache() + + # Validate config. + if self.config.alibi and self.config.flash_attention: + raise Exception("ALiBi is currently not supported with FlashAttention") + + if self.config.alibi and self.config.rope: + raise Exception("ALiBi and RoPE are mutually exclusive") + + # # Temp for T2S infernece + # self.config.embedding_size += 4096 + + if self.config.embedding_size is not None and self.config.embedding_size != self.config.vocab_size: + if self.config.embedding_size < self.config.vocab_size: + print(self.config.embedding_size) + print(self.config.vocab_size) + raise Exception("embedding size should be at least as big as vocab size") + elif self.config.embedding_size % 128 != 0: + import warnings + + warnings.warn( + "Embedding size is not a multiple of 128! This could hurt throughput performance.", UserWarning + ) + + self.activation_checkpointing_strategy: Optional[ActivationCheckpointingStrategy] = None + self._activation_checkpoint_fn: Callable = activation_checkpoint_function(self.config) + + if not ( + 0 < self.config.block_group_size <= self.config.n_layers + and self.config.n_layers % self.config.block_group_size == 0 + ): + raise Exception("n layers must be divisible by block group size") + + torch.backends.cuda.enable_flash_sdp(True) + torch.backends.cuda.enable_mem_efficient_sdp(False) # this is super slow so make sure torch won't use it + + self.transformer = nn.ModuleDict( + dict( + wte=nn.Embedding( + config.embedding_size or config.vocab_size, config.d_model, device=config.init_device + ), + emb_drop=Dropout(config.embedding_dropout), + ln_f=LayerNorm.build(config), + ) + ) + + blocks = [LLaDABlock.build(i, config, self.__cache) for i in range(config.n_layers)] + if self.config.block_group_size > 1: + block_groups = [ + LLaDABlockGroup(config, i, blocks[i : i + config.block_group_size]) + for i in range(0, config.n_layers, config.block_group_size) + ] + self.transformer.update({"block_groups": nn.ModuleList(block_groups)}) + else: + self.transformer.update({"blocks": nn.ModuleList(blocks)}) + + if not (self.config.alibi or self.config.rope): + self.transformer.update( + {"wpe": nn.Embedding(config.max_sequence_length, config.d_model, device=config.init_device)} + ) + if not config.weight_tying: + self.transformer.update( + { + "ff_out": nn.Linear( + config.d_model, + config.embedding_size or config.vocab_size, + bias=config.include_bias, + device=config.init_device, + ) + } + ) + # When `init_device="meta"` FSDP will call `reset_parameters()` to initialize weights. + if init_params and self.config.init_device != "meta": + self.reset_parameters() + self.__num_fwd_flops: Optional[int] = None + + # Warm up cache. + if self.config.alibi: + get_causal_attention_bias(self.__cache, config.max_sequence_length, _non_meta_init_device(config)) + self.get_alibi_attention_bias(config.max_sequence_length, _non_meta_init_device(config)) + + def set_activation_checkpointing(self, strategy: Optional[ActivationCheckpointingStrategy]): + self.activation_checkpointing_strategy = strategy + if self.config.block_group_size != 1: + for block_group in self.transformer.block_groups: + block_group.set_activation_checkpointing(strategy) + else: + for block in self.transformer.blocks: + block.set_activation_checkpointing(strategy) + + @property + def device(self) -> torch.device: + device: torch.device = self.transformer.wte.weight.device # type: ignore + if device.type == "meta": + return _non_meta_init_device(self.config) + else: + return device + + def reset_parameters(self): + log.info("Initializing model parameters...") + # Top-level embeddings / linear layers. + init_weights( + self.config, + self.transformer.wte, # type: ignore + std_factor=(0.5 * math.sqrt(self.config.d_model)) if self.config.scale_logits else 1.0, + type_of_module=ModuleType.emb, + ) + if hasattr(self.transformer, "wpe"): + init_weights(self.config, self.transformer.wpe, type_of_module=ModuleType.emb) # type: ignore + + # Top-level layer norm. + self.transformer.ln_f.reset_parameters() # type: ignore + + # Output weights. + if hasattr(self.transformer, "ff_out"): + init_weights(self.config, self.transformer.ff_out, type_of_module=ModuleType.final_out) # type: ignore + + # Let the blocks handle themselves. + if self.config.block_group_size == 1: + for block in self.transformer.blocks: + block.reset_parameters() + else: + for block_group in self.transformer.block_groups: + block_group.reset_parameters() + + def get_alibi_attention_bias(self, seq_len: int, device: torch.device) -> torch.Tensor: + if (alibi_bias := self.__cache.get("alibi_attention_bias")) is not None and alibi_bias.shape[ + -1 + ] >= seq_len: + if alibi_bias.device != device: + alibi_bias = alibi_bias.to(device) + self.__cache["alibi_attention_bias"] = alibi_bias + return alibi_bias + with torch.autocast(device.type, enabled=False): + alibi_bias = alibi_attention_bias(seq_len, self.config, device) + self.__cache["alibi_attention_bias"] = alibi_bias + return alibi_bias + + def forward( + self, + input_ids: torch.LongTensor, + input_embeddings: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + attention_bias: Optional[torch.Tensor] = None, + past_key_values: Optional[Sequence[Tuple[torch.Tensor, torch.Tensor]]] = None, + use_cache: bool = False, + last_logits_only: bool = False, + output_hidden_states: Optional[bool] = None, + ) -> LLaDAOutput: + """ + :param input_ids: A tensor of shape `(batch_size, seq_len)`. + :param input_embeddings: A tensor of shape `(batch_size, seq_len, d_model)` with input + embeddings. When provided, it is treated as the output of the input embedding layer. + :param attention_mask: A tensor of shape `(batch_size, seq_len)` that indicates + which input IDs are masked. A `1` value in the mask means that + the corresponding input ID should *not* be ignored. A `0` means + that the corresponding input ID is masked. + + This has the same meaning as the `attention_mask` in HuggingFace's `transformers` + library. + :param attention_bias: A tensor of shape `(batch_size, 1, seq_len, seq_len)`, + `(1, 1, seq_len, seq_len)`, or `(seq_len, seq_len)`. This is used + to introduce causal or other biases. + + If the tensor is a bool or byte tensor, a `True` or `1` at `attention_bias[:, :, i, j]` + indicates that the i-th element in the sequence is allowed to attend to the j-th + element in the sequence. + + If the tensor is a float tensor, it will just be added to the attention + scores before the softmax. + + The default is causal, which corresponds to a lower-diagonal byte matrix of ones. + :param past_key_values: Pre-computed keys and values for each attention block. + Can be used to speed up sequential decoding. The `input_ids` which have + their past given to this model should not be passed as `input_ids` as they have already been computed. + :param use_cache: If `True`, return key and value tensors for each block. + :param last_logits_only: If `True`, only compute the logits for the last token of each sequence. + This can speed up decoding when you only care about the next token. + """ + # Add Basic MDM Model config check + assert not self.config.alibi, "Alibi length extrapolation is not supported for MDM." + assert self.config.rope, "Rope must be used in Llama-Encoder for MDM." + assert (past_key_values is None and not use_cache), "The kvcache is not suppotred for MDM." + + output_hidden_states = output_hidden_states if output_hidden_states is not None else False + + if past_key_values: + assert len(past_key_values) == self.config.n_layers + + batch_size, seq_len = input_ids.size() if input_embeddings is None else input_embeddings.size()[:2] + if past_key_values is None: + past_length = 0 + else: + past_length = past_key_values[0][0].size(-2) + + # Get embeddings of input. + # shape: (batch_size, seq_len, d_model) + # print(f"input_ids: {input_ids}, input_ids.shape: {input_ids.shape}") + # print(f"transformer wte weight shape: {self.transformer.wte.weight.shape}") + x = self.transformer.wte(input_ids) if input_embeddings is None else input_embeddings # type: ignore + + # print(f"xshape: {x.shape}") + + if self.config.input_emb_norm: + x = x * (self.config.d_model**0.5) + + if not (self.config.alibi or self.config.rope): + # Get positional embeddings. + # shape: (1, seq_len) + pos = torch.arange(past_length, past_length + seq_len, dtype=torch.long, device=x.device).unsqueeze(0) + # shape: (1, seq_len, d_model) + pos_emb = self.transformer.wpe(pos) # type: ignore + x = pos_emb + x + + # Add input + positional embeddings and apply dropout. + # shape: (batch_size, seq_len, d_model) + x = self.transformer.emb_drop(x) # type: ignore + + # Transform the attention mask into what the blocks expect. + if attention_mask is not None and 0.0 in attention_mask: + # shape: (batch_size, 1, 1, seq_len) + attention_mask = attention_mask.to(dtype=torch.float).view(batch_size, -1)[:, None, None, :] + attention_mask = (1.0 - attention_mask) * torch.finfo(attention_mask.dtype).min + else: + attention_mask = None + + # Merge attention mask with attention bias. + if ( + attention_bias is not None + or attention_mask is not None + or self.config.alibi + # NOTE (epwalsh): we need to initialize the attn bias in order for attn to work properly + # with key+value cache. Otherwise `F.scaled_dot_product_attention()` doesn't seem to compute + # scores correctly. + or past_key_values is not None + ): + if attention_bias is None and self.config.alibi: + # print(f"get_causal_attention_bias") + attention_bias = get_causal_attention_bias( + self.__cache, past_length + seq_len, x.device + ) + self.get_alibi_attention_bias(past_length + seq_len, x.device) + elif attention_bias is None: + # print(f"get_causal_attention_bias") + attention_bias = get_causal_attention_bias(self.__cache, past_length + seq_len, x.device) + elif attention_bias.dtype in (torch.int8, torch.bool): + # print(f"attention_bias.dtype in (torch.int8, torch.bool)") + attention_bias = attention_bias.to(dtype=torch.float) + attention_bias.masked_fill_(attention_bias == 0.0, torch.finfo(attention_bias.dtype).min) + + # Transform to the right shape and data type. + mask_len = seq_len + if attention_mask is not None: + mask_len = attention_mask.shape[-1] + elif past_key_values is not None: + mask_len = past_key_values[0][0].shape[-2] + seq_len + attention_bias = attention_bias[:, :, :mask_len, :mask_len].to(dtype=torch.float) + + # Add in the masking bias. + if attention_mask is not None: + attention_bias = attention_bias + attention_mask + # Might get -infs after adding attention mask, since dtype.min + dtype.min = -inf. + # `F.scaled_dot_product_attention()` doesn't handle -inf like you'd expect, instead + # it can produce NaNs. + ensure_finite_(attention_bias, check_neg_inf=True, check_pos_inf=False) + + attn_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = [] if use_cache else None + + # decoder layers + all_hidden_states = [] + + # Apply blocks one-by-one. + if self.config.block_group_size == 1: + for block_idx, block in enumerate(self.transformer.blocks): + if output_hidden_states: + # add hidden states + all_hidden_states.append(x) + + layer_past = None if past_key_values is None else past_key_values[block_idx] + if ( + (self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.whole_layer) + or ( + self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.one_in_two + and block_idx % 2 == 0 + ) + or ( + self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.one_in_three + and block_idx % 3 == 0 + ) + or ( + self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.one_in_four + and block_idx % 4 == 0 + ) + ): + # shape: (batch_size, seq_len, d_model) + x, cache = self._activation_checkpoint_fn( + block, x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache + ) + else: + # shape: (batch_size, seq_len, d_model) + x, cache = block(x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache) + if attn_key_values is not None: + assert cache is not None + attn_key_values.append(cache) + else: + for group_idx, block_group in enumerate(self.transformer.block_groups): + if output_hidden_states: + # add hidden states + all_hidden_states.append(x) + + layers_past = ( + None + if past_key_values is None + else past_key_values[ + group_idx * self.config.block_group_size : (group_idx + 1) * self.config.block_group_size + ] + ) + x, cache = block_group( + x, attention_bias=attention_bias, layers_past=layers_past, use_cache=use_cache + ) + if attn_key_values is not None: + assert cache is not None + attn_key_values.extend(cache) + + if last_logits_only: + # shape: (batch_size, 1, d_model) + x = x[:, -1, :].unsqueeze(1) + + # Apply final layer norm. + # shape: (batch_size, seq_len or 1, d_model) + x = self.transformer.ln_f(x) # type: ignore + if output_hidden_states: + # add final hidden state post-final-layernorm, following HuggingFace's convention + all_hidden_states.append(x) + + # Get logits. + # shape: (batch_size, seq_len or 1, vocab_size) + if self.config.weight_tying: + logits = F.linear(x, self.transformer.wte.weight, None) # type: ignore + else: + logits = self.transformer.ff_out(x) # type: ignore + if self.config.scale_logits: + logits.mul_(1 / math.sqrt(self.config.d_model)) + + return LLaDAOutput(logits=logits, attn_key_values=attn_key_values, hidden_states=tuple(all_hidden_states) if output_hidden_states else None) # type: ignore[arg-type] + + +def create_model_config_from_pretrained_config(config: LLaDAConfig): + """ + Utility function + """ + + kwargs = {} + for field in fields(ModelConfig): + kwargs[field.name] = getattr(config, field.name) + + model_config = ModelConfig(**kwargs) + return model_config + + +class LLaDAModelLM(PreTrainedModel): + """ + Extremely barebones HF model wrapper. + """ + + config_class = LLaDAConfig + base_model_prefix = "model" + _no_split_modules = ["LLaDABlock", "LLaDASequentialBlock", "LLaDALlamaBlock"] + + def __init__(self, config: LLaDAConfig, model: Optional[LLaDAModel] = None, init_params: bool = False): + super().__init__(config) + + if not model: + model_config = create_model_config_from_pretrained_config(config) + # Initialize model (always on CPU to start with so we don't run out of GPU memory). + model_config.init_device = "cpu" + self.model = LLaDAModel(model_config, init_params=init_params) + else: + self.model = model + + def forward( + self, + input_ids: torch.LongTensor = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + attention_bias: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[Cache] = None, # This is a hack mitigation of an issue in transformers `4.39.x` + ) -> Union[Tuple, CausalLMOutputWithPast]: + if use_cache is None: + use_cache = self.config.use_cache + + if output_attentions: + raise ValueError("output_attentions is not yet supported in LLaDA") + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model.forward( + input_ids=input_ids, + input_embeddings=inputs_embeds, + attention_mask=attention_mask, + attention_bias=attention_bias, + past_key_values=None, + use_cache=False, + output_hidden_states=output_hidden_states, + ) + + logits = outputs.logits + hidden_states = outputs.hidden_states + + loss = None + if labels is not None: + import warnings + warnings.warn("Note that for LLaDA, you cannot calculate the loss here.", UserWarning) + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + logits=logits, + past_key_values=outputs.attn_key_values, + hidden_states=hidden_states, + ) + + def can_generate(self) -> bool: + return True + + def prepare_inputs_for_generation( + self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple]] = None, **kwargs + ): + if past_key_values: + # This is because we want the model to only process the last generated token. + input_ids = input_ids[:, -1:] + model_inputs = {"input_ids": input_ids, "past_key_values": past_key_values} + + model_inputs.update(kwargs) + model_inputs["use_cache"] = kwargs.pop("use_cache", self.config.use_cache) + return model_inputs + + # TODO: these are required to make the implementation complete. + # def resize_position_embeddings(self, new_num_position_embeddings: int): + # pass + # + # def get_position_embeddings(self) -> Union[nn.Embedding, Tuple[nn.Embedding]]: + # pass + # + # def _reorder_cache(self, past_key_values, beam_idx): + # pass + + def get_input_embeddings(self) -> torch.nn.Module: + return self.model.transformer.wte + + def set_input_embeddings(self, value: torch.nn.Module): + self.model.transformer.wte = value + + def get_output_embeddings(self): + if self.config.weight_tying: + return self.model.transformer.wte + else: + return self.model.transformer.ff_out + + def set_output_embeddings(self, value: torch.nn.Module): + if self.config.weight_tying: + self.model.transformer.wte = value + else: + self.model.transformer.ff_out = value + + def tie_weights(self): + if self.config.weight_tying: + self.model.transformer.ff_out = self.model.transformer.wte + +# Register the model so that it is available for transformer pipelines, auto-loading, etc. +AutoModel.register(LLaDAConfig, LLaDAModelLM) \ No newline at end of file diff --git a/MMaDA/models/modeling_magvitv2.py b/MMaDA/models/modeling_magvitv2.py new file mode 100644 index 0000000000000000000000000000000000000000..10ad088d283fed18c076dd2281f960b3fe02af56 --- /dev/null +++ b/MMaDA/models/modeling_magvitv2.py @@ -0,0 +1,441 @@ +from dataclasses import dataclass, field +import numpy as np +import torch +import torch.nn as nn +from .common_modules import * +from .modeling_utils import ConfigMixin, ModelMixin, register_to_config +from .misc import * +import math + +class Updateable: + def do_update_step( + self, epoch: int, global_step: int, on_load_weights: bool = False + ): + for attr in self.__dir__(): + if attr.startswith("_"): + continue + try: + module = getattr(self, attr) + except: + continue # ignore attributes like property, which can't be retrived using getattr? + if isinstance(module, Updateable): + module.do_update_step( + epoch, global_step, on_load_weights=on_load_weights + ) + self.update_step(epoch, global_step, on_load_weights=on_load_weights) + + def do_update_step_end(self, epoch: int, global_step: int): + for attr in self.__dir__(): + if attr.startswith("_"): + continue + try: + module = getattr(self, attr) + except: + continue # ignore attributes like property, which can't be retrived using getattr? + if isinstance(module, Updateable): + module.do_update_step_end(epoch, global_step) + self.update_step_end(epoch, global_step) + + def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False): + # override this method to implement custom update logic + # if on_load_weights is True, you should be careful doing things related to model evaluations, + # as the models and tensors are not guarenteed to be on the same device + pass + + def update_step_end(self, epoch: int, global_step: int): + pass + +class VQGANEncoder(ModelMixin, ConfigMixin): + @dataclass + class Config: + ch: int = 128 + ch_mult: List[int] = field(default_factory=lambda: [1, 2, 2, 4, 4]) + num_res_blocks: List[int] = field(default_factory=lambda: [4, 3, 4, 3, 4]) + attn_resolutions: List[int] = field(default_factory=lambda: [5]) + dropout: float = 0.0 + in_ch: int = 3 + out_ch: int = 3 + resolution: int = 256 + z_channels: int = 13 + double_z: bool = False + + def __init__(self, + ch: int = 128, + ch_mult: List[int] = [1, 2, 2, 4, 4], + num_res_blocks: List[int] = [4, 3, 4, 3, 4], + attn_resolutions: List[int] = [5], + dropout: float = 0.0, + in_ch: int = 3, + out_ch: int = 3, + resolution: int = 256, + z_channels: int = 13, + double_z: bool = False): + super().__init__() + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_ch = in_ch + # downsampling + self.conv_in = torch.nn.Conv2d( + self.in_ch, self.ch, kernel_size=3, stride=1, padding=1 + ) + + curr_res = self.resolution + in_ch_mult = (1,) + tuple(ch_mult) + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = self.ch * in_ch_mult[i_level] + block_out = self.ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks[i_level]): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock(block_in)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in, True) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + + + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d( + block_in, + 2 * z_channels if double_z else z_channels, + kernel_size=3, + stride=1, + padding=1, + ) + + self.quant_conv = torch.nn.Conv2d(z_channels, z_channels, 1) + # for param in self.parameters(): + # broadcast(param, src=0) + + def forward(self, x): + # timestep embedding + temb = None + + # downsampling + # Segmentation Fault + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks[i_level]): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + h = self.quant_conv(h) + return h + + +class LFQuantizer(nn.Module): + def __init__(self, num_codebook_entry: int = -1, + codebook_dim: int = 13, + beta: float = 0.25, + entropy_multiplier: float = 0.1, + commit_loss_multiplier: float = 0.1, ): + super().__init__() + self.codebook_size = 2 ** codebook_dim + print( + f"Look-up free quantizer with codebook size: {self.codebook_size}" + ) + self.e_dim = codebook_dim + self.beta = beta + + indices = torch.arange(self.codebook_size) + + binary = ( + indices.unsqueeze(1) + >> torch.arange(codebook_dim - 1, -1, -1, dtype=torch.long) + ) & 1 + + embedding = binary.float() * 2 - 1 + self.register_buffer("embedding", embedding) + self.register_buffer( + "power_vals", 2 ** torch.arange(codebook_dim - 1, -1, -1) + ) + self.commit_loss_multiplier = commit_loss_multiplier + self.entropy_multiplier = entropy_multiplier + + def get_indices(self, z_q): + return ( + (self.power_vals.reshape(1, -1, 1, 1) * (z_q > 0).float()) + .sum(1, keepdim=True) + .long() + ) + + def get_codebook_entry(self, indices, shape=None): + if shape is None: + h, w = int(math.sqrt(indices.shape[-1])), int(math.sqrt(indices.shape[-1])) + else: + h, w = shape + b, _ = indices.shape + indices = indices.reshape(-1) + z_q = self.embedding[indices] + z_q = z_q.view(b, h, w, -1) + + # reshape back to match original input shape + z_q = z_q.permute(0, 3, 1, 2).contiguous() + + return z_q + + def forward(self, z, get_code=False): + """ + Inputs the output of the encoder network z and maps it to a discrete + one-hot vector that is the index of the closest embedding vector e_j + z (continuous) -> z_q (discrete) + z.shape = (batch, channel, height, width) + quantization pipeline: + 1. get encoder input (B,C,H,W) + 2. flatten input to (B*H*W,C) + """ + if get_code: + return self.get_codebook_entry(z) + + # reshape z -> (batch, height, width, channel) and flatten + z = z.permute(0, 2, 3, 1).contiguous() + z_flattened = z.view(-1, self.e_dim) + ge_zero = (z_flattened > 0).float() + ones = torch.ones_like(z_flattened) + z_q = ones * ge_zero + -ones * (1 - ge_zero) + + # preserve gradients + z_q = z_flattened + (z_q - z_flattened).detach() + + # compute entropy loss + CatDist = torch.distributions.categorical.Categorical + logit = torch.stack( + [ + -(z_flattened - torch.ones_like(z_q)).pow(2), + -(z_flattened - torch.ones_like(z_q) * -1).pow(2), + ], + dim=-1, + ) + cat_dist = CatDist(logits=logit) + entropy = cat_dist.entropy().mean() + mean_prob = cat_dist.probs.mean(0) + mean_entropy = CatDist(probs=mean_prob).entropy().mean() + + # compute loss for embedding + commit_loss = torch.mean( + (z_q.detach() - z_flattened) ** 2 + ) + self.beta * torch.mean((z_q - z_flattened.detach()) ** 2) + + # reshape back to match original input shape + z_q = z_q.view(z.shape) + z_q = z_q.permute(0, 3, 1, 2).contiguous() + + return { + "z": z_q, + "quantizer_loss": commit_loss * self.commit_loss_multiplier, + "entropy_loss": (entropy - mean_entropy) * self.entropy_multiplier, + "indices": self.get_indices(z_q), + } + + +class VQGANDecoder(ModelMixin, ConfigMixin): + def __init__(self, ch: int = 128, + ch_mult: List[int] = [1, 1, 2, 2, 4], + num_res_blocks: List[int] = [4, 4, 3, 4, 3], + attn_resolutions: List[int] = [5], + dropout: float = 0.0, + in_ch: int = 3, + out_ch: int = 3, + resolution: int = 256, + z_channels: int = 13, + double_z: bool = False): + super().__init__() + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_ch = in_ch + self.give_pre_end = False + + self.z_channels = z_channels + # compute in_ch_mult, block_in and curr_res at lowest res + in_ch_mult = (1,) + tuple(ch_mult) + block_in = ch * ch_mult[self.num_resolutions - 1] + curr_res = self.resolution // 2 ** (self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + print( + "Working with z of shape {} = {} dimensions.".format( + self.z_shape, np.prod(self.z_shape) + ) + ) + + # z to block_in + self.conv_in = torch.nn.Conv2d( + z_channels, block_in, kernel_size=3, stride=1, padding=1 + ) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks[i_level]): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock(block_in)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, True) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d( + block_in, out_ch, kernel_size=3, stride=1, padding=1 + ) + self.post_quant_conv = torch.nn.Conv2d( + z_channels, z_channels, 1 + ) + + + def forward(self, z): + # assert z.shape[1:] == self.z_shape[1:] + self.last_z_shape = z.shape + # timestep embedding + temb = None + output = dict() + z = self.post_quant_conv(z) + + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks[i_level]): + h = self.up[i_level].block[i_block](h, temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + output["output"] = h + if self.give_pre_end: + return output + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + output["output"] = h + return output + + +class MAGVITv2(ModelMixin, ConfigMixin): + @register_to_config + def __init__( + self, + ): + super().__init__() + + self.encoder = VQGANEncoder() + self.decoder = VQGANDecoder() + self.quantize = LFQuantizer() + + def forward(self, pixel_values, return_loss=False): + pass + + def encode(self, pixel_values, return_loss=False): + hidden_states = self.encoder(pixel_values) + quantized_states = self.quantize(hidden_states)['z'] + codebook_indices = self.quantize.get_indices(quantized_states).reshape(pixel_values.shape[0], -1) + output = (quantized_states, codebook_indices) + return output + + def get_code(self, pixel_values): + hidden_states = self.encoder(pixel_values) + codebook_indices = self.quantize.get_indices(self.quantize(hidden_states)['z']).reshape(pixel_values.shape[0], -1) + + return codebook_indices + + def decode_code(self, codebook_indices, shape=None): + z_q = self.quantize.get_codebook_entry(codebook_indices, shape=shape) + + reconstructed_pixel_values = self.decoder(z_q)["output"] + return reconstructed_pixel_values + + +if __name__ == '__main__': + encoder = VQGANEncoder() + import ipdb + ipdb.set_trace() + print() \ No newline at end of file diff --git a/MMaDA/models/modeling_mmada.py b/MMaDA/models/modeling_mmada.py new file mode 100644 index 0000000000000000000000000000000000000000..1633951d0242e03c03455bd2009d85b1b24726e3 --- /dev/null +++ b/MMaDA/models/modeling_mmada.py @@ -0,0 +1,1116 @@ +from __future__ import annotations + +import logging +import math +import sys +from abc import abstractmethod +from collections import defaultdict +from functools import partial +from typing import ( + Callable, + Dict, + Iterable, + List, + NamedTuple, + Optional, + Sequence, + Set, + Tuple, + cast, +) +from dataclasses import fields +from typing import List, Optional, Tuple, Union +import numpy as np +import torch +import torch.backends.cuda +import torch.nn as nn +import torch.nn.functional as F +from torch import einsum +from transformers import PreTrainedModel +from transformers.modeling_outputs import CausalLMOutputWithPast +from transformers.models.auto import AutoModel, AutoConfig, AutoModelForCausalLM +from transformers.cache_utils import Cache +from PIL import Image +from .configuration_llada import ( + LLaDAConfig, + StrEnum, + InitFnType, + ActivationType, + BlockType, + LayerNormType, + ModelConfig, + ActivationCheckpointingStrategy, +) + +from .modeling_llada import LLaDAModelLM +from .modeling_video_encoder import VideoEncoder +from .sampling import cosine_schedule, mask_by_random_topk +from transformers import PretrainedConfig + +def add_gumbel_noise(logits, temperature): + ''' + The Gumbel max is a method for sampling categorical distributions. + According to arXiv:2409.02908, for MDM, low-precision Gumbel Max improves perplexity score but reduces generation quality. + Thus, we use float64. + ''' + if temperature == 0: + return logits + logits = logits.to(torch.float64) + noise = torch.rand_like(logits, dtype=torch.float64) + gumbel_noise = (- torch.log(noise)) ** temperature + return logits.exp() / gumbel_noise + + +def get_num_transfer_tokens(mask_index, steps): + ''' + In the reverse process, the interval [0, 1] is uniformly discretized into steps intervals. + Furthermore, because LLaDA employs a linear noise schedule (as defined in Eq. (8)), + the expected number of tokens transitioned at each step should be consistent. + + This function is designed to precompute the number of tokens that need to be transitioned at each step. + ''' + mask_num = mask_index.sum(dim=1, keepdim=True) + + base = mask_num // steps + remainder = mask_num % steps + + num_transfer_tokens = torch.zeros(mask_num.size(0), steps, device=mask_index.device, dtype=torch.int64) + base + + for i in range(mask_num.size(0)): + num_transfer_tokens[i, :remainder[i]] += 1 + + return num_transfer_tokens + +class MMadaConfig(PretrainedConfig): + model_type = "mmada" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + allowed_keys = [ + "vocab_size", + "llm_vocab_size", + "llm_model_path", + "codebook_size", + "num_vq_tokens", + "num_new_special_tokens", + "gradient_checkpointing", + "new_vocab_size", + ] + + for key in allowed_keys: + if key in kwargs: + setattr(self, key, kwargs[key]) + + + +class MMadaModelLM(LLaDAModelLM): + config_class = MMadaConfig + base_model_prefix = "model" + def __init__(self, config: MMadaConfig, *args, **kwargs): + print(f"Initializing MMadaModelLM with config: {config}") + super().__init__(config, *args, **kwargs) + + # # resize token embeddings + # print(f"Resizing token embeddings to {config.new_vocab_size}") + # self.resize_token_embeddings(config.new_vocab_size) + + @torch.no_grad() + def t2i_generate( + self, + input_ids: torch.LongTensor = None, + uncond_input_ids: torch.LongTensor = None, + attention_mask=None, + uncond_attention_mask=None, + temperature=1.0, + timesteps=18, # ideal number of steps is 18 in maskgit paper + guidance_scale=0, + noise_schedule=cosine_schedule, + generator: torch.Generator = None, + config=None, + seq_len=1024, + mask_token_id = 126336, + resolution = 512, + codebook_size = 8192, + **kwargs, + ): + """ + Generate 1:1 similar to the original MaskGit repo + https://github.com/google-research/maskgit/blob/main/maskgit/libml/parallel_decode.py#L79 + """ + + # begin with all image token ids masked + # č®”ē®—ęœ‰å¤šå°‘äøŖmask token + mask_count = (input_ids == mask_token_id).sum().item() + num_vq_tokens = seq_len + num_new_special_tokens = 0 + uni_prompting = kwargs.get("uni_prompting", None) + # print(f"config.model.mmada.llm_vocab_size: {config.model.mmada.llm_vocab_size}, {len(uni_prompting.text_tokenizer)}") + input_ids_minus_lm_vocab_size = input_ids[:, -(num_vq_tokens + 1):-1].clone() + input_ids_minus_lm_vocab_size = torch.where(input_ids_minus_lm_vocab_size == mask_token_id, mask_token_id, input_ids_minus_lm_vocab_size - len(uni_prompting.text_tokenizer) - num_new_special_tokens) + + # for classifier-free guidance + if uncond_input_ids is not None: + uncond_prefix = uncond_input_ids[:, :resolution + 1] + + for step in range(timesteps): + if uncond_input_ids is not None and guidance_scale > 0: + uncond_input_ids = torch.cat( + [uncond_prefix, input_ids[:, resolution + 1:]], dim=1) + model_input = torch.cat([input_ids, uncond_input_ids]) + all_attention_mask = torch.cat([attention_mask, uncond_attention_mask], dim=0) + attention_bias = (all_attention_mask[:, :, None] & all_attention_mask[:, None, :]).bool().unsqueeze(1) + logits = self(model_input, attention_bias=attention_bias).logits + # print(f"logits.shape: {logits.shape}") + cond_logits, uncond_logits = torch.chunk(logits, 2, dim=0) + # logits = uncond_logits + guidance_scale * (cond_logits - uncond_logits) + # it seems that muse has a different cfg setting + logits = (1 + guidance_scale) * cond_logits - guidance_scale * uncond_logits + logits = logits[:, -(num_vq_tokens + 1):-1, len(uni_prompting.text_tokenizer) + num_new_special_tokens: len(uni_prompting.text_tokenizer) + num_new_special_tokens + codebook_size] + else: + attention_bias = (attention_mask[:, :, None] & attention_mask[:, None, :]).bool().unsqueeze(1) + logits = self(input_ids, attention_bias=attention_bias).logits + logits = logits[:, -(num_vq_tokens + 1):-1, len(uni_prompting.text_tokenizer) + num_new_special_tokens: len(uni_prompting.text_tokenizer) + num_new_special_tokens + codebook_size] + + # logits: 1, 1024, 8192 + # print(f"logits.shape: {logits.shape}") + probs = logits.softmax(dim=-1) + sampled = probs.reshape(-1, logits.size(-1)) + # print(f"probs: {probs}, probs.shape: {probs.shape}, sampled: {sampled}, sampled.shape: {sampled.shape}") + sampled_ids = torch.multinomial(sampled, 1, generator=generator)[:, 0].view(*logits.shape[:-1]) # 1, 1024 + + unknown_map = input_ids_minus_lm_vocab_size == mask_token_id + # print(f"unknown_map.sum(dim=-1, keepdim=True): {unknown_map.sum(dim=-1, keepdim=True)}") + sampled_ids = torch.where(unknown_map, sampled_ids, input_ids_minus_lm_vocab_size) + # Defines the mask ratio for the next round. The number to mask out is + # determined by mask_ratio * unknown_number_in_the_beginning. + ratio = 1.0 * (step + 1) / timesteps + mask_ratio = noise_schedule(torch.tensor(ratio)) + # Computes the probabilities of each selected tokens. + selected_probs = torch.gather(probs, -1, sampled_ids.long()[..., None]) + selected_probs = selected_probs.squeeze(-1) + + # Ignores the tokens given in the input by overwriting their confidence. + selected_probs = torch.where(unknown_map, selected_probs, torch.finfo(selected_probs.dtype).max) + # Gets mask lens for each sample in the batch according to the mask ratio. + mask_len = (num_vq_tokens * mask_ratio).floor().unsqueeze(0).to(logits.device) + # Keeps at least one of prediction in this round and also masks out at least + # one and for the next iteration + mask_len = torch.max( + torch.tensor([1], device=logits.device), torch.min(unknown_map.sum(dim=-1, keepdim=True) - 1, mask_len) + ) + # print(f"mask_len: {mask_len}, mask_len.shape: {mask_len.shape}") + # Adds noise for randomness + temperature = temperature * (1.0 - ratio) + masking = mask_by_random_topk(mask_len, selected_probs, temperature, generator=generator) + # Masks tokens with lower confidence. + input_ids[:, -(num_vq_tokens + 1):-1] = torch.where(masking, mask_token_id, + sampled_ids + len(uni_prompting.text_tokenizer) + + num_new_special_tokens) + input_ids_minus_lm_vocab_size = torch.where(masking, mask_token_id, sampled_ids) + + return sampled_ids + + @torch.no_grad() + def t2s_generate( + self, + input_ids: torch.LongTensor = None, + uncond_input_ids: torch.LongTensor = None, + attention_mask=None, + uncond_attention_mask=None, + temperature=1.0, + timesteps=18, # ideal number of steps is 18 in maskgit paper + guidance_scale=0, + noise_schedule=cosine_schedule, + generator: torch.Generator = None, + config=None, + seq_len=100, + mask_token_id = 126336, + resolution = 512, + codebook_size = 8192, + **kwargs, + ): + """ + Generate 1:1 similar to the original MaskGit repo + https://github.com/google-research/maskgit/blob/main/maskgit/libml/parallel_decode.py#L79 + """ + + # begin with all image token ids masked + # č®”ē®—ęœ‰å¤šå°‘äøŖmask token + mask_count = (input_ids == mask_token_id).sum().item() + num_vq_tokens = seq_len + num_new_special_tokens = 0 + uni_prompting = kwargs.get("uni_prompting", None) + + speech_vocab_start_idx = len(uni_prompting.text_tokenizer) + 8192 + speech_vocab_end_idx = speech_vocab_start_idx + 4096 + + # print(f"config.model.mmada.llm_vocab_size: {config.model.mmada.llm_vocab_size}, {len(uni_prompting.text_tokenizer)}") + input_ids_minus_lm_vocab_size = input_ids[:, -(num_vq_tokens + 1):-1].clone() + input_ids_minus_lm_vocab_size = torch.where(input_ids_minus_lm_vocab_size == mask_token_id, mask_token_id, input_ids_minus_lm_vocab_size - len(uni_prompting.text_tokenizer) - num_new_special_tokens) + + # for classifier-free guidance + if uncond_input_ids is not None: + uncond_prefix = uncond_input_ids[:, :resolution + 1] + + for step in range(timesteps): + if uncond_input_ids is not None and guidance_scale > 0: + uncond_input_ids = torch.cat( + [uncond_prefix, input_ids[:, resolution + 1:]], dim=1) + model_input = torch.cat([input_ids, uncond_input_ids]) + all_attention_mask = torch.cat([attention_mask, uncond_attention_mask], dim=0) + attention_bias = (all_attention_mask[:, :, None] & all_attention_mask[:, None, :]).bool().unsqueeze(1) + logits = self(model_input, attention_bias=attention_bias).logits + # print(f"logits.shape: {logits.shape}") + cond_logits, uncond_logits = torch.chunk(logits, 2, dim=0) + # logits = uncond_logits + guidance_scale * (cond_logits - uncond_logits) + # it seems that muse has a different cfg setting + logits = (1 + guidance_scale) * cond_logits - guidance_scale * uncond_logits + logits = logits[:, -num_vq_tokens-1:-1, speech_vocab_start_idx:speech_vocab_end_idx] + else: + attention_bias = (attention_mask[:, :, None] & attention_mask[:, None, :]).bool().unsqueeze(1) + logits = self(input_ids, attention_bias=attention_bias).logits + logits = logits[:, -num_vq_tokens-1:-1, speech_vocab_start_idx:speech_vocab_end_idx] + + # logits: 1, 1024, 4096 + print(f"logits.shape: {logits.shape}") + probs = logits.softmax(dim=-1) + sampled = probs.reshape(-1, logits.size(-1)) + # print(f"probs: {probs}, probs.shape: {probs.shape}, sampled: {sampled}, sampled.shape: {sampled.shape}") + sampled_ids = torch.multinomial(sampled, 1, generator=generator)[:, 0].view(*logits.shape[:-1]) # 1, 1024 + + unknown_map = input_ids_minus_lm_vocab_size == mask_token_id + # print(f"unknown_map.sum(dim=-1, keepdim=True): {unknown_map.sum(dim=-1, keepdim=True)}") + sampled_ids = torch.where(unknown_map, sampled_ids, input_ids_minus_lm_vocab_size) + # Defines the mask ratio for the next round. The number to mask out is + # determined by mask_ratio * unknown_number_in_the_beginning. + ratio = 1.0 * (step + 1) / timesteps + mask_ratio = noise_schedule(torch.tensor(ratio)) + # Computes the probabilities of each selected tokens. + selected_probs = torch.gather(probs, -1, sampled_ids.long()[..., None]) + selected_probs = selected_probs.squeeze(-1) + + # Ignores the tokens given in the input by overwriting their confidence. + selected_probs = torch.where(unknown_map, selected_probs, torch.finfo(selected_probs.dtype).max) + # Gets mask lens for each sample in the batch according to the mask ratio. + mask_len = (num_vq_tokens * mask_ratio).floor().unsqueeze(0).to(logits.device) + # Keeps at least one of prediction in this round and also masks out at least + # one and for the next iteration + mask_len = torch.max( + torch.tensor([1], device=logits.device), torch.min(unknown_map.sum(dim=-1, keepdim=True) - 1, mask_len) + ) + # print(f"mask_len: {mask_len}, mask_len.shape: {mask_len.shape}") + # Adds noise for randomness + temperature = temperature * (1.0 - ratio) + masking = mask_by_random_topk(mask_len, selected_probs, temperature, generator=generator) + # Masks tokens with lower confidence. + input_ids[:, -(num_vq_tokens + 1):-1] = torch.where(masking, mask_token_id, + sampled_ids + len(uni_prompting.text_tokenizer) + + num_new_special_tokens) + input_ids_minus_lm_vocab_size = torch.where(masking, mask_token_id, sampled_ids) + + return sampled_ids + + @torch.no_grad() + def i2i_generate( + self, + input_ids: torch.LongTensor = None, + uncond_input_ids: torch.LongTensor = None, + attention_mask=None, + uncond_attention_mask=None, + temperature=1.0, + timesteps=18, # ideal number of steps is 18 in maskgit paper + guidance_scale=0, + noise_schedule=cosine_schedule, + generator: torch.Generator = None, + config=None, + seq_len=1024, + mask_token_id = 126336, + resolution = 512, + codebook_size = 8192, + **kwargs, + ): + """ + Generate 1:1 similar to the original MaskGit repo + https://github.com/google-research/maskgit/blob/main/maskgit/libml/parallel_decode.py#L79 + """ + + # begin with all image token ids masked + # č®”ē®—ęœ‰å¤šå°‘äøŖmask token + mask_count = (input_ids == mask_token_id).sum().item() + num_vq_tokens = seq_len + num_new_special_tokens = 0 + uni_prompting = kwargs.get("uni_prompting", None) + # print(f"config.model.mmada.llm_vocab_size: {config.model.mmada.llm_vocab_size}, {len(uni_prompting.text_tokenizer)}") + input_ids_minus_lm_vocab_size = input_ids[:, -(num_vq_tokens + 1):-1].clone() + input_ids_minus_lm_vocab_size = torch.where(input_ids_minus_lm_vocab_size == mask_token_id, mask_token_id, input_ids_minus_lm_vocab_size - len(uni_prompting.text_tokenizer) - num_new_special_tokens) + + # for classifier-free guidance + if uncond_input_ids is not None: + uncond_prefix = uncond_input_ids[:, :resolution + 1] + + for step in range(timesteps): + if uncond_input_ids is not None and guidance_scale > 0: + uncond_input_ids = torch.cat( + [uncond_prefix, input_ids[:, resolution + 1:]], dim=1) + model_input = torch.cat([input_ids, uncond_input_ids]) + all_attention_mask = torch.cat([attention_mask, uncond_attention_mask], dim=0) + attention_bias = (all_attention_mask[:, :, None] & all_attention_mask[:, None, :]).bool().unsqueeze(1) + logits = self(model_input, attention_bias=attention_bias).logits + # print(f"logits.shape: {logits.shape}") + cond_logits, uncond_logits = torch.chunk(logits, 2, dim=0) + # logits = uncond_logits + guidance_scale * (cond_logits - uncond_logits) + # it seems that muse has a different cfg setting + logits = (1 + guidance_scale) * cond_logits - guidance_scale * uncond_logits + logits = logits[:, -(num_vq_tokens + 1):-1, len(uni_prompting.text_tokenizer) + num_new_special_tokens: len(uni_prompting.text_tokenizer) + num_new_special_tokens + codebook_size] + else: + attention_bias = (attention_mask[:, :, None] & attention_mask[:, None, :]).bool().unsqueeze(1) + logits = self(input_ids, attention_bias=attention_bias).logits + logits = logits[:, -(num_vq_tokens + 1):-1, len(uni_prompting.text_tokenizer) + num_new_special_tokens: len(uni_prompting.text_tokenizer) + num_new_special_tokens + codebook_size] + + # logits: 1, 1024, 8192 + # print(f"logits.shape: {logits.shape}") + probs = logits.softmax(dim=-1) + sampled = probs.reshape(-1, logits.size(-1)) + # print(f"probs: {probs}, probs.shape: {probs.shape}, sampled: {sampled}, sampled.shape: {sampled.shape}") + sampled_ids = torch.multinomial(sampled, 1, generator=generator)[:, 0].view(*logits.shape[:-1]) # 1, 1024 + + unknown_map = input_ids_minus_lm_vocab_size == mask_token_id + # print(f"unknown_map.sum(dim=-1, keepdim=True): {unknown_map.sum(dim=-1, keepdim=True)}") + sampled_ids = torch.where(unknown_map, sampled_ids, input_ids_minus_lm_vocab_size) + # Defines the mask ratio for the next round. The number to mask out is + # determined by mask_ratio * unknown_number_in_the_beginning. + ratio = 1.0 * (step + 1) / timesteps + mask_ratio = noise_schedule(torch.tensor(ratio)) + # Computes the probabilities of each selected tokens. + selected_probs = torch.gather(probs, -1, sampled_ids.long()[..., None]) + selected_probs = selected_probs.squeeze(-1) + + # Ignores the tokens given in the input by overwriting their confidence. + selected_probs = torch.where(unknown_map, selected_probs, torch.finfo(selected_probs.dtype).max) + # Gets mask lens for each sample in the batch according to the mask ratio. + mask_len = (num_vq_tokens * mask_ratio).floor().unsqueeze(0).to(logits.device) + # Keeps at least one of prediction in this round and also masks out at least + # one and for the next iteration + mask_len = torch.max( + torch.tensor([1], device=logits.device), torch.min(unknown_map.sum(dim=-1, keepdim=True) - 1, mask_len) + ) + # print(f"mask_len: {mask_len}, mask_len.shape: {mask_len.shape}") + # Adds noise for randomness + temperature = temperature * (1.0 - ratio) + masking = mask_by_random_topk(mask_len, selected_probs, temperature, generator=generator) + # Masks tokens with lower confidence. + input_ids[:, -(num_vq_tokens + 1):-1] = torch.where(masking, mask_token_id, + sampled_ids + len(uni_prompting.text_tokenizer) + + num_new_special_tokens) + input_ids_minus_lm_vocab_size = torch.where(masking, mask_token_id, sampled_ids) + + return sampled_ids + + def forward_process( + self, + input_ids, + labels, + batch_size_t2i=0, + batch_size_lm=0, + batch_size_mmu=0, + max_seq_length=128, + p_mask_lm=None, + p_mask_mmu=None, + answer_lengths=None, + t2i_masks=None, + answer_lengths_lm=None + ): + # attention bias, True for batch_size, 1, seq_len, seq_len + attention_bias = torch.ones(input_ids.shape[0], 1, input_ids.shape[1], input_ids.shape[1]) + attention_bias_t2i = (t2i_masks[:, :, None] & t2i_masks[:, None, :]).bool().unsqueeze(1) + attention_bias[:batch_size_t2i] = attention_bias_t2i + logits = self(input_ids, attention_bias=attention_bias).logits + self.output_size = logits.shape[-1] + + if batch_size_t2i == 0: + loss_t2i = torch.tensor(0.0, device=input_ids.device) + else: + loss_t2i = F.cross_entropy( + logits[:batch_size_t2i, max_seq_length + 1:].contiguous().view(-1, self.output_size), + labels[:batch_size_t2i, max_seq_length + 1:].contiguous().view(-1), ignore_index=-100, + ) + + masked_indices = input_ids == self.config.mask_token_id + masked_indices_lm = masked_indices[batch_size_t2i:batch_size_t2i + batch_size_lm] + masked_indices_mmu = masked_indices[-batch_size_mmu:] + p_mask_lm = p_mask_lm.to(masked_indices_lm.device) + p_mask_mmu = p_mask_mmu.to(masked_indices_mmu.device) + answer_lengths = answer_lengths.to(masked_indices_mmu.device) + loss_lm = F.cross_entropy( + logits[batch_size_t2i:batch_size_t2i + batch_size_lm][masked_indices_lm].contiguous().view(-1, self.output_size), + labels[batch_size_t2i:batch_size_t2i + batch_size_lm][masked_indices_lm].contiguous().view(-1), ignore_index=-100, reduction='none' + )/p_mask_lm[masked_indices_lm] + + if answer_lengths_lm is not None: + loss_lm = torch.sum(loss_lm / answer_lengths_lm[masked_indices_lm]) / (logits[batch_size_t2i:batch_size_t2i + batch_size_lm].shape[0]) + else: + loss_lm = loss_lm.sum() / (logits[batch_size_t2i:batch_size_t2i + batch_size_lm].shape[0] * logits[batch_size_t2i:batch_size_t2i + batch_size_lm].shape[1]) + + loss_mmu = F.cross_entropy( + logits[-batch_size_mmu:][masked_indices_mmu].contiguous().view(-1, self.output_size), + labels[-batch_size_mmu:][masked_indices_mmu].contiguous().view(-1), ignore_index=-100, reduction='none' + )/p_mask_mmu[masked_indices_mmu] + loss_mmu = torch.sum(loss_mmu/answer_lengths[masked_indices_mmu]) / (logits[-batch_size_mmu:].shape[0]) + + return logits, loss_t2i, loss_lm, loss_mmu + + def forward_process_with_r2i( + self, + input_ids, + labels, + t2i_masks=None, + max_seq_length=128, + batch_size_t2i=0, + batch_size_lm=0, + batch_size_mmu=0, + batch_size_r2i=0, + p_mask_lm=None, + p_mask_mmu=None, + p_mask_r2i=None, + answer_lengths=None, + answer_lengths_lm=None, + answer_lengths_r2i=None, + ): + # attention bias, True for batch_size, 1, seq_len, seq_len + attention_bias = torch.ones(input_ids.shape[0], 1, input_ids.shape[1], input_ids.shape[1]) + attention_bias_t2i = (t2i_masks[:, :, None] & t2i_masks[:, None, :]).bool().unsqueeze(1) + attention_bias[:batch_size_t2i] = attention_bias_t2i + logits = self(input_ids, attention_bias=attention_bias).logits + # logits = self(input_ids).logits + self.output_size = logits.shape[-1] + + if batch_size_t2i == 0: + loss_t2i = torch.tensor(0.0, device=input_ids.device) + else: + # t2i loss + loss_t2i = F.cross_entropy( + logits[:batch_size_t2i, max_seq_length + 1:].contiguous().view(-1, self.output_size), + labels[:batch_size_t2i, max_seq_length + 1:].contiguous().view(-1), ignore_index=-100, + ) + + # llada loss + + start_lm = batch_size_t2i + end_lm = start_lm + batch_size_lm + start_mmu = end_lm + end_mmu = start_mmu + batch_size_mmu + start_r2i = end_mmu + end_r2i = start_r2i + batch_size_r2i + + masked_indices = input_ids == self.config.mask_token_id + masked_indices_lm = masked_indices[start_lm:end_lm] + masked_indices_mmu = masked_indices[start_mmu:end_mmu] + masked_indices_r2i = masked_indices[start_r2i:end_r2i] + + p_mask_lm = p_mask_lm.to(masked_indices_lm.device) + p_mask_mmu = p_mask_mmu.to(masked_indices_mmu.device) + p_mask_r2i = p_mask_r2i.to(masked_indices_r2i.device) + + answer_lengths = answer_lengths.to(masked_indices_mmu.device) + answer_lengths_lm = answer_lengths_lm.to(masked_indices_lm.device) + answer_lengths_r2i = answer_lengths_r2i.to(masked_indices_r2i.device) + + loss_lm = F.cross_entropy( + logits[start_lm:end_lm][masked_indices_lm].contiguous().view(-1, self.output_size), + labels[start_lm:end_lm][masked_indices_lm].contiguous().view(-1), ignore_index=-100, reduction='none' + )/p_mask_lm[masked_indices_lm] + + if answer_lengths_lm is not None: + loss_lm = torch.sum(loss_lm / answer_lengths_lm[masked_indices_lm]) / (logits[start_lm:end_lm].shape[0]) + else: + loss_lm = loss_lm.sum() / (logits[start_lm:end_lm].shape[0] * logits[start_lm:end_lm].shape[1]) + + loss_mmu = F.cross_entropy( + logits[start_mmu:end_mmu][masked_indices_mmu].contiguous().view(-1, self.output_size), + labels[start_mmu:end_mmu][masked_indices_mmu].contiguous().view(-1), ignore_index=-100, reduction='none' + )/p_mask_mmu[masked_indices_mmu] + loss_mmu = torch.sum(loss_mmu/answer_lengths[masked_indices_mmu]) / (logits[start_mmu:end_mmu].shape[0]) + + loss_r2i = F.cross_entropy( + logits[start_r2i:end_r2i][masked_indices_r2i].contiguous().view(-1, self.output_size), + labels[start_r2i:end_r2i][masked_indices_r2i].contiguous().view(-1), ignore_index=-100, reduction='none' + )/p_mask_r2i[masked_indices_r2i] + loss_r2i = torch.sum(loss_r2i/answer_lengths_r2i[masked_indices_r2i]) / (logits[start_r2i:end_r2i].shape[0]) + + return logits, loss_t2i, loss_lm, loss_mmu, loss_r2i + + def forward_t2i( + self, + input_ids, + labels, + batch_size_t2i=0, + max_seq_length=128, + t2i_masks=None + ): + # attention bias, True for batch_size, 1, seq_len, seq_len + attention_bias = torch.ones(input_ids.shape[0], 1, input_ids.shape[1], input_ids.shape[1]) + attention_bias_t2i = (t2i_masks[:, :, None] & t2i_masks[:, None, :]).bool().unsqueeze(1) + attention_bias[:batch_size_t2i] = attention_bias_t2i + logits = self(input_ids, attention_bias=attention_bias).logits + # logits = self(input_ids).logits + self.output_size = logits.shape[-1] + + # print(f"logits shape: {logits.shape}") B, 359, vocab_size + + loss_t2i = F.cross_entropy( + logits[:batch_size_t2i, max_seq_length + 1:].contiguous().view(-1, self.output_size), + labels[:batch_size_t2i, max_seq_length + 1:].contiguous().view(-1), ignore_index=-100, + ) + + return loss_t2i + + # Temp + def forward_i2i(self, input_ids, attention_mask, labels): + """ + Forward pass for the I2I task. + """ + outputs = self( + input_ids=input_ids, + attention_mask=attention_mask + ) + logits = outputs.logits + + loss = F.cross_entropy( + logits.view(-1, logits.size(-1)), + labels.view(-1), + ignore_index=-100 + ) + + return logits, loss + + # Temp + def forward_s2t( + self, + input_ids, + labels, + batch_size_s2t=0, + max_seq_length=128, + p_mask_s2t=None, + answer_lengths=None, + ): + # attention bias, True for batch_size, 1, seq_len, seq_len + attention_bias = torch.ones(input_ids.shape[0], 1, input_ids.shape[1], input_ids.shape[1], device=input_ids.device) + logits = self(input_ids, attention_bias=attention_bias).logits + self.output_size = logits.shape[-1] + + masked_indices = input_ids == self.config.mask_token_id + masked_indices_s2t = masked_indices[-batch_size_s2t:] + p_mask_s2t = p_mask_s2t.to(masked_indices_s2t.device) + answer_lengths = answer_lengths.to(masked_indices_s2t.device) + + loss_s2t = F.cross_entropy( + logits[-batch_size_s2t:][masked_indices_s2t].contiguous().view(-1, self.output_size), + labels[-batch_size_s2t:][masked_indices_s2t].contiguous().view(-1), ignore_index=-100, reduction='none' + )/p_mask_s2t[masked_indices_s2t] + loss_s2t = torch.sum(loss_s2t/answer_lengths[masked_indices_s2t]) / (logits[-batch_size_s2t:].shape[0]) + + return logits, loss_s2t + + def forward_t2s( + self, + input_ids, + labels, + batch_size_t2s=0, + max_seq_length=128, + p_mask_t2s=None, + answer_lengths=None, + ): + """ + Forward pass for text-to-speech (T2S) diffusion LM training. + + Args: + input_ids: (B, L) Input token IDs (text + [MASK]*len(speech)). + labels: (B, L) Target speech codebook token IDs. + batch_size_t2s: Batch size for t2s task (for multitask batches). + max_seq_length: Prompt(text) źøøģ“ + p_mask_t2s: (B, L) Mask probability per position (optional). + answer_lengths: (B,) 각 row별 target length (optional). + Returns: + logits, loss_t2s + """ + attention_bias = torch.ones(input_ids.shape[0], 1, input_ids.shape[1], input_ids.shape[1], device=input_ids.device) + logits = self(input_ids, attention_bias=attention_bias).logits + self.output_size = logits.shape[-1] + + masked_indices = input_ids == self.config.mask_token_id + masked_indices_t2s = masked_indices[-batch_size_t2s:] + p_mask_t2s = p_mask_t2s.to(masked_indices_t2s.device) + answer_lengths = answer_lengths.to(masked_indices_t2s.device) + + loss_t2s = F.cross_entropy( + logits[-batch_size_t2s:][masked_indices_t2s].contiguous().view(-1, self.output_size), + labels[-batch_size_t2s:][masked_indices_t2s].contiguous().view(-1), + ignore_index=-100, reduction='none' + ) / p_mask_t2s[masked_indices_t2s] + loss_t2s = torch.sum(loss_t2s / answer_lengths[masked_indices_t2s]) / logits[-batch_size_t2s:].shape[0] + + return logits, loss_t2s + + def forward_v2t( + self, + input_ids, + labels, + batch_size_v2t=0, + max_seq_length=128, + p_mask_v2t=None, + answer_lengths=None, + ): + """ + video-to-text (V2T) diffusion LM training. + """ + attention_bias = torch.ones(input_ids.shape[0], 1, input_ids.shape[1], input_ids.shape[1], device=input_ids.device) + logits = self(input_ids, attention_bias=attention_bias).logits + self.output_size = logits.shape[-1] + + masked_indices = input_ids == self.config.mask_token_id + masked_indices_v2t = masked_indices[:batch_size_v2t] + p_mask_v2t = p_mask_v2t.to(masked_indices_v2t.device) + answer_lengths = answer_lengths.to(masked_indices_v2t.device) + + loss_v2t = F.cross_entropy( + logits[:batch_size_v2t][masked_indices_v2t].contiguous().view(-1, self.output_size), + labels[:batch_size_v2t][masked_indices_v2t].contiguous().view(-1), + ignore_index=-100, + reduction='none' + ) / p_mask_v2t[masked_indices_v2t] + loss_v2t = torch.sum(loss_v2t / answer_lengths[masked_indices_v2t]) / (logits[:batch_size_v2t].shape[0]) + return logits, loss_v2t + + def forward_v2t_encoder( + self, + input_ids, + labels, + batch_size_v2t=0, + max_seq_length=128, + p_mask_v2t=None, + answer_lengths=None, + ): + """ + video-to-text (V2T) diffusion LM training. + """ + attention_bias = torch.ones(input_ids.shape[0], 1, input_ids.shape[1], input_ids.shape[1], device=input_ids.device) + input_embeddings = super().model.transformer.wte(input_ids) + + + logits = self(input_ids, attention_bias=attention_bias).logits + self.output_size = logits.shape[-1] + + masked_indices = input_ids == self.config.mask_token_id + masked_indices_v2t = masked_indices[:batch_size_v2t] + p_mask_v2t = p_mask_v2t.to(masked_indices_v2t.device) + answer_lengths = answer_lengths.to(masked_indices_v2t.device) + + loss_v2t = F.cross_entropy( + logits[:batch_size_v2t][masked_indices_v2t].contiguous().view(-1, self.output_size), + labels[:batch_size_v2t][masked_indices_v2t].contiguous().view(-1), + ignore_index=-100, + reduction='none' + ) / p_mask_v2t[masked_indices_v2t] + loss_v2t = torch.sum(loss_v2t / answer_lengths[masked_indices_v2t]) / (logits[:batch_size_v2t].shape[0]) + return logits, loss_v2t + + + # def forward_i2i(self, input_ids, attention_mask, labels, max_prompt_length): + # """ + # Forward pass for the I2I task. + # """ + # outputs = self( + # input_ids=input_ids, + # attention_mask=attention_mask + # ) + # logits = outputs.logits + + # logits_for_loss = logits[:, max_prompt_length:].contiguous() + # labels_for_loss = labels[:, max_prompt_length:].contiguous() + + # loss = F.cross_entropy( + # logits_for_loss.view(-1, logits_for_loss.size(-1)), + # labels_for_loss.view(-1), + # ignore_index=-100 + # ) + + # return logits, loss + + @torch.no_grad() + def mmu_generate(self, idx=None, input_embeddings=None, max_new_tokens=128, steps=128,block_length=128, temperature=0.0, top_k=None, eot_token=None, cfg_scale=0.0, remasking='low_confidence', mask_id=126336, attention_mask=None): + """ + Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete + the sequence max_new_tokens times, feeding the predictions back into the model each time. + Most likely you'll want to make sure to be in model.eval() mode of operation for this. + """ + + if attention_mask is not None and 0.0 in attention_mask: + attention_bias = (attention_mask[:, :, None] & attention_mask[:, None, :]).bool().unsqueeze(1) + # print(f"attention_bias: {attention_bias}") + else: + attention_bias = None + try: + device = idx.device + except: + device = input_embeddings.device + + result = [] + batch_size = idx.shape[0] + x = torch.full((batch_size, idx.shape[1] + max_new_tokens), mask_id, dtype=torch.long).to(self.device) + x[:, :idx.shape[1]] = idx.clone() + prompt_index = (x != mask_id) + + + assert max_new_tokens % block_length == 0 + num_blocks = max_new_tokens // block_length + + assert steps % num_blocks == 0 + steps = steps // num_blocks + + # print(f"num_blocks: {num_blocks}, steps: {steps}") + # num_transfer_tokens = get_num_transfer_tokens(prompt_index, steps) + for num_block in range(num_blocks): + block_mask_index = (x[:, idx.shape[1] + num_block * block_length: idx.shape[1] + (num_block + 1) * block_length:] == mask_id) + num_transfer_tokens = get_num_transfer_tokens(block_mask_index, steps) + # num_transfer_tokens = get_num_transfer_tokens(prompt_index, steps) + # print(f"num_transfer_tokens: {num_transfer_tokens}, num_transfer_tokens.shape: {num_transfer_tokens.shape}") + for i in range(steps): + mask_index = (x == mask_id) + if cfg_scale > 0.0: + un_x = x.clone() + un_x[prompt_index] = mask_id + x_ = torch.cat([x, un_x], dim=0) + logits = self(x_).logits + logits, un_logits = torch.chunk(logits, 2, dim=0) + logits = un_logits + (cfg_scale + 1) * (logits - un_logits) + else: + logits = self(x, attention_bias=attention_bias).logits + + logits_with_noise = add_gumbel_noise(logits, temperature=temperature) + x0 = torch.argmax(logits_with_noise, dim=-1) # b, l + if remasking == 'low_confidence': + p = F.softmax(logits.to(torch.float64), dim=-1) + x0_p = torch.squeeze( + torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1) # b, l + elif remasking == 'random': + x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device) + else: + raise NotImplementedError(remasking) + + x0_p[:, idx.shape[1] + (num_block + 1) * block_length:] = -np.inf + + x0 = torch.where(mask_index, x0, x) + confidence = torch.where(mask_index, x0_p, -np.inf) + + transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device) + for j in range(confidence.shape[0]): + _, select_index = torch.topk(confidence[j], k=num_transfer_tokens[j, i]) + transfer_index[j, select_index] = True + x[transfer_index] = x0[transfer_index] + + + # logits = logits[:, -1, :] / temperature + # # optionally crop the logits to only the top k options + # if top_k is not None: + # v, _ = torch.topk(logits, min(top_k, logits.size(-1))) + # logits[logits < v[:, [-1]]] = -float('Inf') + # # apply softmax to convert logits to (normalized) probabilities + # probs = F.softmax(logits, dim=-1) + # # sample from the distribution + # idx_next = torch.multinomial(probs, num_samples=1) + # result.append(idx_next[0][0]) + # # append sampled index to the running sequence and continue + # if self.config.w_clip_vit: + # idx_next_embeddings = self.mmada.model.embed_tokens(idx_next) + # input_embeddings = torch.cat([input_embeddings, idx_next_embeddings], dim=1) + # else: + # idx = torch.cat((idx, idx_next), dim=1) + + # if eot_token is not None and idx_next.cpu() == eot_token: + # break + + return x + + + @torch.no_grad() + def s2t_generate(self, idx=None, input_embeddings=None, max_new_tokens=128, steps=128,block_length=128, temperature=0.0, top_k=None, eot_token=None, cfg_scale=0.0, remasking='low_confidence', mask_id=126336, attention_mask=None): + """ + Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete + the sequence max_new_tokens times, feeding the predictions back into the model each time. + Most likely you'll want to make sure to be in model.eval() mode of operation for this. + """ + + if attention_mask is not None and 0.0 in attention_mask: + attention_bias = (attention_mask[:, :, None] & attention_mask[:, None, :]).bool().unsqueeze(1) + # print(f"attention_bias: {attention_bias}") + else: + attention_bias = None + try: + device = idx.device + except: + device = input_embeddings.device + + result = [] + batch_size = idx.shape[0] + x = torch.full((batch_size, idx.shape[1] + max_new_tokens), mask_id, dtype=torch.long).to(self.device) + x[:, :idx.shape[1]] = idx.clone() + prompt_index = (x != mask_id) + + + assert max_new_tokens % block_length == 0 + num_blocks = max_new_tokens // block_length + + assert steps % num_blocks == 0 + steps = steps // num_blocks + + # print(f"num_blocks: {num_blocks}, steps: {steps}") + # num_transfer_tokens = get_num_transfer_tokens(prompt_index, steps) + for num_block in range(num_blocks): + block_mask_index = (x[:, idx.shape[1] + num_block * block_length: idx.shape[1] + (num_block + 1) * block_length:] == mask_id) + num_transfer_tokens = get_num_transfer_tokens(block_mask_index, steps) + # num_transfer_tokens = get_num_transfer_tokens(prompt_index, steps) + # print(f"num_transfer_tokens: {num_transfer_tokens}, num_transfer_tokens.shape: {num_transfer_tokens.shape}") + for i in range(steps): + mask_index = (x == mask_id) + if cfg_scale > 0.0: + un_x = x.clone() + un_x[prompt_index] = mask_id + x_ = torch.cat([x, un_x], dim=0) + logits = self(x_).logits + logits, un_logits = torch.chunk(logits, 2, dim=0) + logits = un_logits + (cfg_scale + 1) * (logits - un_logits) + else: + logits = self(x, attention_bias=attention_bias).logits + + logits_with_noise = add_gumbel_noise(logits, temperature=temperature) + x0 = torch.argmax(logits_with_noise, dim=-1) # b, l + if remasking == 'low_confidence': + p = F.softmax(logits.to(torch.float64), dim=-1) + x0_p = torch.squeeze( + torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1) # b, l + elif remasking == 'random': + x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device) + else: + raise NotImplementedError(remasking) + + x0_p[:, idx.shape[1] + (num_block + 1) * block_length:] = -np.inf + + x0 = torch.where(mask_index, x0, x) + confidence = torch.where(mask_index, x0_p, -np.inf) + + transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device) + for j in range(confidence.shape[0]): + _, select_index = torch.topk(confidence[j], k=num_transfer_tokens[j, i]) + transfer_index[j, select_index] = True + x[transfer_index] = x0[transfer_index] + + + # logits = logits[:, -1, :] / temperature + # # optionally crop the logits to only the top k options + # if top_k is not None: + # v, _ = torch.topk(logits, min(top_k, logits.size(-1))) + # logits[logits < v[:, [-1]]] = -float('Inf') + # # apply softmax to convert logits to (normalized) probabilities + # probs = F.softmax(logits, dim=-1) + # # sample from the distribution + # idx_next = torch.multinomial(probs, num_samples=1) + # result.append(idx_next[0][0]) + # # append sampled index to the running sequence and continue + # if self.config.w_clip_vit: + # idx_next_embeddings = self.mmada.model.embed_tokens(idx_next) + # input_embeddings = torch.cat([input_embeddings, idx_next_embeddings], dim=1) + # else: + # idx = torch.cat((idx, idx_next), dim=1) + + # if eot_token is not None and idx_next.cpu() == eot_token: + # break + + return x + + @torch.no_grad() + def mmu_generate_fast(self, idx=None, input_embeddings=None, max_new_tokens=128, steps=128,block_length=128, temperature=0.0, top_k=None, eot_token=None, cfg_scale=0.0, remasking='low_confidence', mask_id=126336, attention_mask=None): + """ + Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete + the sequence max_new_tokens times, feeding the predictions back into the model each time. + Most likely you'll want to make sure to be in model.eval() mode of operation for this. + """ + + if attention_mask is not None and 0.0 in attention_mask: + attention_bias = (attention_mask[:, :, None] & attention_mask[:, None, :]).bool().unsqueeze(1) + # print(f"attention_bias: {attention_bias}") + else: + attention_bias = None + try: + device = idx.device + except: + device = input_embeddings.device + + result = [] + batch_size = idx.shape[0] + x = torch.full((batch_size, idx.shape[1] + max_new_tokens), mask_id, dtype=torch.long).to(self.device) + x[:, :idx.shape[1]] = idx.clone() + prompt_index = (x != mask_id) + + + assert max_new_tokens % block_length == 0 + num_blocks = max_new_tokens // block_length + + assert steps % num_blocks == 0 + steps = steps // num_blocks + + for num_block in range(num_blocks): + block_mask_index = (x[:, idx.shape[1] + num_block * block_length: idx.shape[1] + (num_block + 1) * block_length:] == mask_id) + num_transfer_tokens = get_num_transfer_tokens(block_mask_index, steps) + for i in range(steps): + mask_index = (x == mask_id) + if cfg_scale > 0.0: + un_x = x.clone() + un_x[prompt_index] = mask_id + x_ = torch.cat([x, un_x], dim=0) + logits = self(x_).logits + logits, un_logits = torch.chunk(logits, 2, dim=0) + logits = un_logits + (cfg_scale + 1) * (logits - un_logits) + else: + logits = self(x, attention_bias=attention_bias).logits + + logits_with_noise = add_gumbel_noise(logits, temperature=temperature) + x0 = torch.argmax(logits_with_noise, dim=-1) # b, l + if remasking == 'low_confidence': + p = F.softmax(logits.to(torch.float64), dim=-1) + x0_p = torch.squeeze( + torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1) # b, l + elif remasking == 'random': + x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device) + else: + raise NotImplementedError(remasking) + + x0_p[:, idx.shape[1] + (num_block + 1) * block_length:] = -np.inf + + x0 = torch.where(mask_index, x0, x) + confidence = torch.where(mask_index, x0_p, -np.inf) + + transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device) + for j in range(confidence.shape[0]): + _, select_index = torch.topk(confidence[j], k=num_transfer_tokens[j, i]) + transfer_index[j, select_index] = True + x[transfer_index] = x0[transfer_index] + if eot_token is not None: + last_token_index_in_current_block = idx.shape[1] + (num_block + 1) * block_length - 1 + if last_token_index_in_current_block < x.shape[1]: + tokens_at_block_end = x[:, last_token_index_in_current_block] + if torch.all(tokens_at_block_end == eot_token): + break + return x + + @torch.no_grad() + def t2i_generate_decoding_stepwise( + self, + input_ids: torch.LongTensor = None, + uncond_input_ids: torch.LongTensor = None, + attention_mask=None, + uncond_attention_mask=None, + temperature=1.0, + timesteps=18, # ideal number of steps is 18 in maskgit paper + guidance_scale=0, + noise_schedule=cosine_schedule, + generator: torch.Generator = None, + config=None, + seq_len=1024, + mask_token_id = 126336, + resolution = 512, + codebook_size = 8192, + vq_model = None, + **kwargs, + ): + """ + Generate 1:1 similar to the original MaskGit repo + https://github.com/google-research/maskgit/blob/main/maskgit/libml/parallel_decode.py#L79 + """ + + # begin with all image token ids masked + # č®”ē®—ęœ‰å¤šå°‘äøŖmask token + mask_count = (input_ids == mask_token_id).sum().item() + num_vq_tokens = seq_len + num_new_special_tokens = 0 + uni_prompting = kwargs.get("uni_prompting", None) + # print(f"config.model.mmada.llm_vocab_size: {config.model.mmada.llm_vocab_size}, {len(uni_prompting.text_tokenizer)}") + input_ids_minus_lm_vocab_size = input_ids[:, -(num_vq_tokens + 1):-1].clone() + input_ids_minus_lm_vocab_size = torch.where(input_ids_minus_lm_vocab_size == mask_token_id, mask_token_id, input_ids_minus_lm_vocab_size - len(uni_prompting.text_tokenizer) - num_new_special_tokens) + + # for classifier-free guidance + if uncond_input_ids is not None: + uncond_prefix = uncond_input_ids[:, :resolution + 1] + + for step in range(timesteps): + if uncond_input_ids is not None and guidance_scale > 0: + uncond_input_ids = torch.cat( + [uncond_prefix, input_ids[:, resolution + 1:]], dim=1) + model_input = torch.cat([input_ids, uncond_input_ids]) + attention_mask = torch.cat([attention_mask, uncond_attention_mask], dim=0) + attention_bias = (attention_mask[:, :, None] & attention_mask[:, None, :]).bool().unsqueeze(1) + logits = self(model_input, attention_bias=attention_bias).logits + # print(f"logits.shape: {logits.shape}") + cond_logits, uncond_logits = torch.chunk(logits, 2, dim=0) + # logits = uncond_logits + guidance_scale * (cond_logits - uncond_logits) + # it seems that muse has a different cfg setting + logits = (1 + guidance_scale) * cond_logits - guidance_scale * uncond_logits + logits = logits[:, -(num_vq_tokens + 1):-1, len(uni_prompting.text_tokenizer) + num_new_special_tokens: len(uni_prompting.text_tokenizer) + num_new_special_tokens + codebook_size] + else: + attention_bias = (attention_mask[:, :, None] & attention_mask[:, None, :]).bool().unsqueeze(1) + logits = self(input_ids, attention_bias=attention_bias).logits + logits = logits[:, -(num_vq_tokens + 1):-1, len(uni_prompting.text_tokenizer) + num_new_special_tokens: len(uni_prompting.text_tokenizer) + num_new_special_tokens + codebook_size] + + # logits: 1, 1024, 8192 + # print(f"logits.shape: {logits.shape}") + probs = logits.softmax(dim=-1) + sampled = probs.reshape(-1, logits.size(-1)) + # print(f"probs: {probs}, probs.shape: {probs.shape}, sampled: {sampled}, sampled.shape: {sampled.shape}") + sampled_ids = torch.multinomial(sampled, 1, generator=generator)[:, 0].view(*logits.shape[:-1]) # 1, 1024 + + unknown_map = input_ids_minus_lm_vocab_size == mask_token_id + # print(f"unknown_map.sum(dim=-1, keepdim=True): {unknown_map.sum(dim=-1, keepdim=True)}") + sampled_ids = torch.where(unknown_map, sampled_ids, input_ids_minus_lm_vocab_size) + # Defines the mask ratio for the next round. The number to mask out is + current_image_vq_indices = sampled_ids.clone() + # print(f"current_image_vq_indices: {current_image_vq_indices}") + current_image_vq_indices = torch.clamp(current_image_vq_indices, 0, 8192 - 1) + current_image = vq_model.decode_code(current_image_vq_indices) + images = torch.clamp((current_image + 1.0) / 2.0, min=0.0, max=1.0) + images *= 255.0 + images = images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8) + pil_images = Image.fromarray(images[0]) + yield pil_images, f"Step {step + 1}/{timesteps}" + # determined by mask_ratio * unknown_number_in_the_beginning. + ratio = 1.0 * (step + 1) / timesteps + mask_ratio = noise_schedule(torch.tensor(ratio)) + # Computes the probabilities of each selected tokens. + selected_probs = torch.gather(probs, -1, sampled_ids.long()[..., None]) + selected_probs = selected_probs.squeeze(-1) + + # Ignores the tokens given in the input by overwriting their confidence. + selected_probs = torch.where(unknown_map, selected_probs, torch.finfo(selected_probs.dtype).max) + # Gets mask lens for each sample in the batch according to the mask ratio. + mask_len = (num_vq_tokens * mask_ratio).floor().unsqueeze(0).to(logits.device) + # Keeps at least one of prediction in this round and also masks out at least + # one and for the next iteration + mask_len = torch.max( + torch.tensor([1], device=logits.device), torch.min(unknown_map.sum(dim=-1, keepdim=True) - 1, mask_len) + ) + # print(f"mask_len: {mask_len}, mask_len.shape: {mask_len.shape}") + # Adds noise for randomness + temperature = temperature * (1.0 - ratio) + masking = mask_by_random_topk(mask_len, selected_probs, temperature, generator=generator) + # Masks tokens with lower confidence. + input_ids[:, -(num_vq_tokens + 1):-1] = torch.where(masking, mask_token_id, + sampled_ids + len(uni_prompting.text_tokenizer) + + num_new_special_tokens) + input_ids_minus_lm_vocab_size = torch.where(masking, mask_token_id, sampled_ids) + + + return sampled_ids + + +AutoConfig.register("mmada", MMadaConfig) +AutoModelForCausalLM.register(MMadaConfig, MMadaModelLM) +AutoModel.register(MMadaConfig, MMadaModelLM) diff --git a/MMaDA/models/modeling_omada.py b/MMaDA/models/modeling_omada.py new file mode 100644 index 0000000000000000000000000000000000000000..368d8c8c3d5f14d0504714564a7037ae3608dc5c --- /dev/null +++ b/MMaDA/models/modeling_omada.py @@ -0,0 +1,1990 @@ +from __future__ import annotations + +import logging +import math +import sys +import warnings +from abc import abstractmethod +from collections import defaultdict +from functools import partial +from typing import ( + Callable, + Dict, + Iterable, + List, + NamedTuple, + Optional, + Sequence, + Set, + Tuple, + cast, +) +from dataclasses import fields +from typing import List, Optional, Tuple, Union +import numpy as np +import torch +import torch.backends.cuda +import torch.nn as nn +import torch.nn.functional as F +from torch import einsum +from transformers import PreTrainedModel +from transformers.modeling_outputs import CausalLMOutputWithPast +from transformers.models.auto import AutoModel, AutoConfig, AutoModelForCausalLM +from transformers.cache_utils import Cache +from PIL import Image +from .configuration_llada import ( + LLaDAConfig, + StrEnum, + InitFnType, + ActivationType, + BlockType, + LayerNormType, + ModelConfig, + ActivationCheckpointingStrategy, +) + +from .modeling_llada import LLaDAModelLM +from .modeling_video_encoder import VideoEncoder +from .sampling import cosine_schedule, mask_by_random_topk +from transformers import PretrainedConfig + +def calculate_mmu_style_loss(logits_batch, labels_batch, masked_indices_batch, p_mask, answer_lengths, output_size, device): + if logits_batch.shape[0] == 0: + return logits_batch.new_zeros(()) + + p_mask_flat = p_mask.to(device)[masked_indices_batch] + p_mask_flat = torch.clamp(p_mask_flat, min=1e-4) + answer_lengths_flat = answer_lengths.to(device)[masked_indices_batch] + answer_lengths_flat = torch.clamp(answer_lengths_flat, min=1) + + loss = F.cross_entropy( + logits_batch[masked_indices_batch].contiguous().view(-1, output_size), + labels_batch[masked_indices_batch].contiguous().view(-1), ignore_index=-100, reduction='none' + ) / p_mask_flat + + loss = torch.sum(loss / answer_lengths_flat) / logits_batch.shape[0] + return loss + + +def calculate_t2s_loss( + logits_batch, + labels_batch, + masked_indices_batch, + p_mask, + answer_lengths, + vocab_start, + codebook_size, + eoa_token_id, + eos_token_id, + device, + ignore_index=-100, +): + if logits_batch.shape[0] == 0: + return logits_batch.new_zeros(()) + + selected_logits = logits_batch[masked_indices_batch] + selected_labels = labels_batch[masked_indices_batch].to(torch.long) + + if selected_logits.shape[0] == 0: + return logits_batch.new_zeros(()) + + work_dtype = torch.float32 + selected_logits_fp32 = selected_logits.to(dtype=work_dtype) + + speech_logits = selected_logits_fp32[:, vocab_start : vocab_start + codebook_size] + eoa_logits = selected_logits_fp32[:, eoa_token_id : eoa_token_id + 1] + eos_logits = selected_logits_fp32[:, eos_token_id : eos_token_id + 1] + combined_logits = torch.cat([speech_logits, eoa_logits, eos_logits], dim=-1) + + p_mask_flat = p_mask.to(device=device, dtype=work_dtype)[masked_indices_batch] + p_mask_flat = torch.clamp(p_mask_flat, min=1e-4) + answer_lengths_flat = answer_lengths.to(device=device, dtype=work_dtype)[masked_indices_batch] + answer_lengths_flat = torch.clamp(answer_lengths_flat, min=1.0) + + relative_labels = torch.full_like(selected_labels, ignore_index) + audio_mask = (selected_labels >= vocab_start) & (selected_labels < vocab_start + codebook_size) + relative_labels[audio_mask] = selected_labels[audio_mask] - vocab_start + relative_labels[selected_labels == eoa_token_id] = codebook_size + relative_labels[selected_labels == eos_token_id] = codebook_size + 1 + + loss_vec = F.cross_entropy( + combined_logits, + relative_labels, + ignore_index=ignore_index, + reduction='none' + ) + + loss_vec = loss_vec / p_mask_flat + loss_vec = loss_vec / answer_lengths_flat + + loss = torch.sum(loss_vec) / logits_batch.shape[0] + return loss.to(dtype=logits_batch.dtype) + +def add_gumbel_noise(logits, temperature): + ''' + The Gumbel max is a method for sampling categorical distributions. + According to arXiv:2409.02908, for MDM, low-precision Gumbel Max improves perplexity score but reduces generation quality. + Thus, we use float64. + ''' + if temperature == 0: + return logits + logits = logits.to(torch.float64) + noise = torch.rand_like(logits, dtype=torch.float64) + gumbel_noise = (- torch.log(noise)) ** temperature + return logits.exp() / gumbel_noise + + +def get_num_transfer_tokens(mask_index, steps): + ''' + In the reverse process, the interval [0, 1] is uniformly discretized into steps intervals. + Furthermore, because LLaDA employs a linear noise schedule (as defined in Eq. (8)), + the expected number of tokens transitioned at each step should be consistent. + + This function is designed to precompute the number of tokens that need to be transitioned at each step. + ''' + mask_num = mask_index.sum(dim=1, keepdim=True) + + base = mask_num // steps + remainder = mask_num % steps + + num_transfer_tokens = torch.zeros(mask_num.size(0), steps, device=mask_index.device, dtype=torch.int64) + base + + for i in range(mask_num.size(0)): + num_transfer_tokens[i, :remainder[i]] += 1 + + return num_transfer_tokens + +class OMadaConfig(PretrainedConfig): + model_type = "omada" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + allowed_keys = [ + "vocab_size", + "llm_vocab_size", + "llm_model_path", + "codebook_size", + "num_vq_tokens", + "num_new_special_tokens", + "gradient_checkpointing", + "new_vocab_size", + ] + + for key in allowed_keys: + if key in kwargs: + setattr(self, key, kwargs[key]) + + + +class OMadaModelLM(LLaDAModelLM): + config_class = OMadaConfig + base_model_prefix = "model" + def __init__(self, config: OMadaConfig, *args, **kwargs): + print(f"Initializing OMadaModelLM with config: {config}") + super().__init__(config, *args, **kwargs) + + # # resize token embeddings + # print(f"Resizing token embeddings to {config.new_vocab_size}") + # self.resize_token_embeddings(config.new_vocab_size) + + @torch.no_grad() + def t2i_generate( + self, + input_ids: torch.LongTensor = None, + uncond_input_ids: torch.LongTensor = None, + attention_mask=None, + uncond_attention_mask=None, + temperature=1.0, + timesteps=18, # ideal number of steps is 18 in maskgit paper + guidance_scale=0, + noise_schedule=cosine_schedule, + generator: torch.Generator = None, + config=None, + seq_len=1024, + mask_token_id = 126336, + resolution = 512, + codebook_size = 8192, + **kwargs, + ): + """ + Generate 1:1 similar to the original MaskGit repo + https://github.com/google-research/maskgit/blob/main/maskgit/libml/parallel_decode.py#L79 + """ + + # begin with all image token ids masked + # č®”ē®—ęœ‰å¤šå°‘äøŖmask token + mask_count = (input_ids == mask_token_id).sum().item() + num_vq_tokens = seq_len + num_new_special_tokens = 0 + uni_prompting = kwargs.get("uni_prompting", None) + # print(f"config.model.mmada.llm_vocab_size: {config.model.mmada.llm_vocab_size}, {len(uni_prompting.text_tokenizer)}") + input_ids_minus_lm_vocab_size = input_ids[:, -(num_vq_tokens + 1):-1].clone() + input_ids_minus_lm_vocab_size = torch.where(input_ids_minus_lm_vocab_size == mask_token_id, mask_token_id, input_ids_minus_lm_vocab_size - len(uni_prompting.text_tokenizer) - num_new_special_tokens) + + # for classifier-free guidance + if uncond_input_ids is not None: + uncond_prefix = uncond_input_ids[:, :resolution + 1] + + for step in range(timesteps): + if uncond_input_ids is not None and guidance_scale > 0: + uncond_input_ids = torch.cat( + [uncond_prefix, input_ids[:, resolution + 1:]], dim=1) + model_input = torch.cat([input_ids, uncond_input_ids]) + all_attention_mask = torch.cat([attention_mask, uncond_attention_mask], dim=0) + attention_bias = (all_attention_mask[:, :, None] & all_attention_mask[:, None, :]).bool().unsqueeze(1) + logits = self(model_input, attention_bias=attention_bias).logits + # print(f"logits.shape: {logits.shape}") + cond_logits, uncond_logits = torch.chunk(logits, 2, dim=0) + # logits = uncond_logits + guidance_scale * (cond_logits - uncond_logits) + # it seems that muse has a different cfg setting + logits = (1 + guidance_scale) * cond_logits - guidance_scale * uncond_logits + logits = logits[:, -(num_vq_tokens + 1):-1, len(uni_prompting.text_tokenizer) + num_new_special_tokens: len(uni_prompting.text_tokenizer) + num_new_special_tokens + codebook_size] + else: + attention_bias = (attention_mask[:, :, None] & attention_mask[:, None, :]).bool().unsqueeze(1) + logits = self(input_ids, attention_bias=attention_bias).logits + logits = logits[:, -(num_vq_tokens + 1):-1, len(uni_prompting.text_tokenizer) + num_new_special_tokens: len(uni_prompting.text_tokenizer) + num_new_special_tokens + codebook_size] + + # logits: 1, 1024, 8192 + # print(f"logits.shape: {logits.shape}") + probs = logits.softmax(dim=-1) + sampled = probs.reshape(-1, logits.size(-1)) + # print(f"probs: {probs}, probs.shape: {probs.shape}, sampled: {sampled}, sampled.shape: {sampled.shape}") + sampled_ids = torch.multinomial(sampled, 1, generator=generator)[:, 0].view(*logits.shape[:-1]) # 1, 1024 + + unknown_map = input_ids_minus_lm_vocab_size == mask_token_id + # print(f"unknown_map.sum(dim=-1, keepdim=True): {unknown_map.sum(dim=-1, keepdim=True)}") + sampled_ids = torch.where(unknown_map, sampled_ids, input_ids_minus_lm_vocab_size) + # Defines the mask ratio for the next round. The number to mask out is + # determined by mask_ratio * unknown_number_in_the_beginning. + ratio = 1.0 * (step + 1) / timesteps + mask_ratio = noise_schedule(torch.tensor(ratio)) + # Computes the probabilities of each selected tokens. + selected_probs = torch.gather(probs, -1, sampled_ids.long()[..., None]) + selected_probs = selected_probs.squeeze(-1) + + # Ignores the tokens given in the input by overwriting their confidence. + selected_probs = torch.where(unknown_map, selected_probs, torch.finfo(selected_probs.dtype).max) + # Gets mask lens for each sample in the batch according to the mask ratio. + mask_len = (num_vq_tokens * mask_ratio).floor().unsqueeze(0).to(logits.device) + # Keeps at least one of prediction in this round and also masks out at least + # one and for the next iteration + mask_len = torch.max( + torch.tensor([1], device=logits.device), torch.min(unknown_map.sum(dim=-1, keepdim=True) - 1, mask_len) + ) + # print(f"mask_len: {mask_len}, mask_len.shape: {mask_len.shape}") + # Adds noise for randomness + temperature = temperature * (1.0 - ratio) + masking = mask_by_random_topk(mask_len, selected_probs, temperature, generator=generator) + # Masks tokens with lower confidence. + input_ids[:, -(num_vq_tokens + 1):-1] = torch.where(masking, mask_token_id, + sampled_ids + len(uni_prompting.text_tokenizer) + + num_new_special_tokens) + input_ids_minus_lm_vocab_size = torch.where(masking, mask_token_id, sampled_ids) + + return sampled_ids + + @torch.no_grad() + def t2s_generate( + self, + input_ids: torch.LongTensor = None, + uncond_input_ids: torch.LongTensor = None, + attention_mask=None, + uncond_attention_mask=None, + temperature=1.0, + timesteps=18, + guidance_scale=0, + noise_schedule=None, + generator: torch.Generator = None, + config=None, + seq_len=256, + mask_token_id=126336, + **kwargs, + ): + uni_prompting = kwargs.get("uni_prompting", None) + if uni_prompting is None: + raise ValueError("uni_prompting object must be provided in kwargs.") + + eoa_token_id = uni_prompting.sptids_dict['<|eoa|>'][0].item() + eos_token_id = uni_prompting.text_tokenizer.eos_token_id + + num_vq_tokens = (input_ids == mask_token_id).sum(dim=-1).max().item() + if num_vq_tokens == 0: + raise ValueError("No mask tokens found in input_ids.") + + speech_vocab_start_idx = len(uni_prompting.text_tokenizer) + 8192 + speech_vocab_end_idx = speech_vocab_start_idx + 4096 + + # VQ Codes: 0 ~ 4095 + # EOA: 4096 + # EOS: 4097 + vq_code_relative_eoa_id = 4096 + vq_code_relative_eos_id = 4097 + + input_ids_relative = input_ids[:, -(num_vq_tokens):].clone() + input_ids_relative = torch.where( + input_ids_relative == mask_token_id, + mask_token_id, + input_ids_relative - speech_vocab_start_idx + ) + + if uncond_input_ids is not None: + start_gen_idx = (uncond_input_ids[0] == uni_prompting.sptids_dict['<|soa|>'][0].item()).nonzero(as_tuple=True)[0][0].item() + 1 + uncond_prefix = uncond_input_ids[:, :start_gen_idx] + + for step in range(timesteps): + if uncond_input_ids is not None and guidance_scale > 0: + uncond_input_ids = torch.cat([uncond_prefix, input_ids[:, start_gen_idx:]], dim=1) + model_input = torch.cat([input_ids, uncond_input_ids]) + all_attention_mask = torch.cat([attention_mask, uncond_attention_mask], dim=0) + + attention_bias = (all_attention_mask[:, :, None] & all_attention_mask[:, None, :]).bool().unsqueeze(1) + logits = self(model_input, attention_bias=attention_bias).logits + cond_logits, uncond_logits = torch.chunk(logits, 2, dim=0) + + logits = (1 + guidance_scale) * cond_logits - guidance_scale * uncond_logits + + else: + attention_bias = (attention_mask[:, :, None] & attention_mask[:, None, :]).bool().unsqueeze(1) + logits = self(input_ids, attention_bias=attention_bias).logits + + logits_vq = logits[:, -(num_vq_tokens):, speech_vocab_start_idx:speech_vocab_end_idx] + logits_eoa = logits[:, -(num_vq_tokens):, eoa_token_id:eoa_token_id+1] + logits_eos = logits[:, -(num_vq_tokens):, eos_token_id:eos_token_id+1] + + combined_logits = torch.cat([logits_vq, logits_eoa, logits_eos], dim=-1) + + probs = combined_logits.softmax(dim=-1) + sampled = probs.reshape(-1, combined_logits.size(-1)) + + sampled_ids_relative = torch.multinomial(sampled, 1, generator=generator)[:, 0].view(*combined_logits.shape[:-1]) + + unknown_map = input_ids_relative == mask_token_id + + sampled_ids_relative = torch.where(unknown_map, sampled_ids_relative, input_ids_relative) + + ratio = 1.0 * (step + 1) / timesteps + mask_ratio = noise_schedule(torch.tensor(ratio, device=logits.device)) + + selected_probs = torch.gather(probs, -1, sampled_ids_relative.long()[..., None]).squeeze(-1) + selected_probs = torch.where(unknown_map, selected_probs, torch.finfo(selected_probs.dtype).max) + + mask_len = (num_vq_tokens * mask_ratio).floor().unsqueeze(0).to(logits.device) + mask_len = torch.max( + torch.tensor([1], device=logits.device), + torch.min(unknown_map.sum(dim=-1, keepdim=True) - 1, mask_len) + ) + + temperature = temperature * (1.0 - ratio) + masking = mask_by_random_topk(mask_len, selected_probs, temperature, generator=generator) + + input_ids[:, -(num_vq_tokens):] = torch.where( + masking, + mask_token_id, + torch.where( + sampled_ids_relative == vq_code_relative_eos_id, + eos_token_id, + torch.where( + sampled_ids_relative == vq_code_relative_eoa_id, + eoa_token_id, + sampled_ids_relative + speech_vocab_start_idx + ) + ) + ) + + input_ids_relative = torch.where(masking, mask_token_id, sampled_ids_relative) + + # print("--- Generation Loop Finished ---") + # print("Final sequence BEFORE post-processing (relative IDs):") + # print(input_ids_relative[0]) + # print(f"Shape: {input_ids_relative.shape}") + # print("---------------------------------") + + final_output_ids = [] + for i in range(input_ids_relative.shape[0]): + seq = input_ids_relative[i] + + eoa_indices = (seq >= vq_code_relative_eoa_id).nonzero(as_tuple=True)[0] + + if eoa_indices.numel() > 0: + first_eoa_idx = eoa_indices[0] + seq = seq[:first_eoa_idx] + + valid_tokens = seq[seq != mask_token_id] + + final_output_ids.append(valid_tokens) + + return final_output_ids + + @torch.no_grad() + def t2s_generate_mmu_like( + self, + input_ids: torch.LongTensor, + max_new_tokens: Optional[int] = None, + steps: int = 256, + block_length: int = 128, + temperature: float = 0.0, + cfg_scale: float = 0.0, + mask_token_id: int = 126336, + attention_mask: Optional[torch.LongTensor] = None, + uni_prompting=None, + codebook_size: Optional[int] = None, + audio_codebook_size: int = 4096, + ): + """ + Generate speech tokens with MMU-style block-wise refinement. + Assumes the speech region within ``input_ids`` is contiguous and filled with ``mask_token_id`` + prior to generation. + """ + + if uni_prompting is None: + raise ValueError("uni_prompting must be provided") + if block_length <= 0: + raise ValueError("block_length must be positive") + + batch_size, seq_len = input_ids.shape + device = input_ids.device + + mask_positions_full = (input_ids == mask_token_id) + if not mask_positions_full.any(): + raise ValueError("No mask tokens detected for T2S generation") + + mask_cols = torch.where(mask_positions_full[0])[0] + speech_region_start = mask_cols[0].item() + speech_region_len = mask_cols.numel() + + mask_counts = mask_positions_full.sum(dim=1) + if not torch.all(mask_counts == mask_counts[0]): + raise ValueError("All batch items must contain the same number of masked speech tokens for MMU-like generation") + + if max_new_tokens is None: + max_new_tokens = speech_region_len + else: + max_new_tokens = min(max_new_tokens, speech_region_len) + + block_length = max(1, min(block_length, max_new_tokens)) + num_blocks = math.ceil(max_new_tokens / block_length) + inner_steps = max(1, steps // num_blocks) + + codebook_base = codebook_size if codebook_size is not None else getattr(self.config, "codebook_size", 8192) + speech_vocab_start = len(uni_prompting.text_tokenizer) + codebook_base + speech_vocab_end = speech_vocab_start + audio_codebook_size + + eoa_token_id = uni_prompting.sptids_dict['<|eoa|>'][0].item() + eos_token_id = uni_prompting.text_tokenizer.eos_token_id + vq_code_relative_eoa_id = audio_codebook_size + vq_code_relative_eos_id = audio_codebook_size + 1 + + work = input_ids.clone() + + attention_bias = None + if attention_mask is not None: + attention_bias = (attention_mask[:, :, None] & attention_mask[:, None, :]).bool().unsqueeze(1) + + speech_indices = mask_cols[:max_new_tokens] + + for block_idx in range(num_blocks): + block_start = block_idx * block_length + block_end = min(block_start + block_length, max_new_tokens) + curr_indices = speech_indices[block_start:block_end] + if curr_indices.numel() == 0: + continue + + block_mask = mask_positions_full[:, curr_indices] + num_transfer_tokens = get_num_transfer_tokens(block_mask, inner_steps) + + for inner_step in range(inner_steps): + if cfg_scale > 0.0: + un_cond = work.clone() + un_cond[:, speech_indices] = mask_token_id + stacked = torch.cat([work, un_cond], dim=0) + if attention_bias is not None: + att_bias = torch.cat([attention_bias, attention_bias], dim=0) + else: + att_bias = None + logits = self(stacked, attention_bias=att_bias).logits + cond_logits, uncond_logits = torch.chunk(logits, 2, dim=0) + logits = uncond_logits + (cfg_scale + 1.0) * (cond_logits - uncond_logits) + else: + logits = self(work, attention_bias=attention_bias).logits + + logits_block = logits.index_select(1, curr_indices.to(device)) + logits_vq = logits_block[:, :, speech_vocab_start:speech_vocab_end] + logits_eoa = logits_block[:, :, eoa_token_id:eoa_token_id + 1] + logits_eos = logits_block[:, :, eos_token_id:eos_token_id + 1] + + combined_logits = torch.cat([logits_vq, logits_eoa, logits_eos], dim=-1) + if temperature > 0.0: + combined_logits = combined_logits / max(temperature, 1e-5) + probs = F.softmax(combined_logits, dim=-1) + + sampled = torch.multinomial( + probs.view(-1, probs.size(-1)), 1 + ).view(batch_size, curr_indices.numel()) + + selected_probs = torch.gather(probs, -1, sampled.unsqueeze(-1)).squeeze(-1) + + eos_tensor = sampled.new_full(sampled.shape, eos_token_id) + eoa_tensor = sampled.new_full(sampled.shape, eoa_token_id) + sampled_absolute = torch.where( + sampled == vq_code_relative_eos_id, + eos_tensor, + torch.where( + sampled == vq_code_relative_eoa_id, + eoa_tensor, + sampled + speech_vocab_start + ) + ) + + current_block_vals = work.index_select(1, curr_indices) + mask_current = current_block_vals == mask_token_id + + confidence = torch.where( + mask_current, + selected_probs, + torch.full_like(selected_probs, float('-inf')) + ) + + finalize = torch.zeros_like(mask_current, dtype=torch.bool) + for b in range(batch_size): + available = mask_current[b].sum().item() + if available == 0: + continue + transfer = min(int(num_transfer_tokens[b, inner_step].item()), available) + if transfer <= 0: + continue + _, idxs = torch.topk(confidence[b], k=transfer, largest=True) + finalize[b, idxs] = True + + mask_fill = sampled_absolute.new_full(sampled_absolute.shape, mask_token_id) + updates = torch.where(finalize, sampled_absolute, mask_fill) + new_block = torch.where(mask_current, updates, current_block_vals) + + work[:, curr_indices] = new_block + + mask_positions_full[:, curr_indices] = new_block == mask_token_id + + if not mask_positions_full[:, curr_indices].any(): + break + + final_outputs = [] + audio_slice = slice(speech_region_start, speech_region_start + speech_region_len) + audio_region = work[:, audio_slice] + + for seq in audio_region: + mask_tensor = seq.new_full(seq.shape, mask_token_id) + rel_eoa = seq.new_full(seq.shape, vq_code_relative_eoa_id) + rel_eos = seq.new_full(seq.shape, vq_code_relative_eos_id) + relative = torch.where( + seq == mask_token_id, + mask_tensor, + torch.where( + seq == eoa_token_id, + rel_eoa, + torch.where( + seq == eos_token_id, + rel_eos, + seq - speech_vocab_start + ) + ) + ) + + eoa_positions = (relative >= vq_code_relative_eoa_id).nonzero(as_tuple=True)[0] + if eoa_positions.numel() > 0: + relative = relative[:eoa_positions[0]] + + final_outputs.append(relative[relative != mask_token_id]) + + return final_outputs + + @torch.no_grad() + def t2s_fixed_generate( + self, + input_ids: torch.LongTensor = None, + uncond_input_ids: torch.LongTensor = None, + attention_mask=None, + uncond_attention_mask=None, + temperature=1.0, + timesteps=18, + guidance_scale=0, + noise_schedule=None, + generator: torch.Generator = None, + config=None, + seq_len=256, + mask_token_id=126336, + **kwargs, + ): + """ + Generate 1:1 similar to the original MaskGit repo + https://github.com/google-research/maskgit/blob/main/maskgit/libml/parallel_decode.py#L79 + """ + + # begin with all image token ids masked + # č®”ē®—ęœ‰å¤šå°‘äøŖmask token + mask_count = (input_ids == mask_token_id).sum().item() + num_vq_tokens = seq_len + num_new_special_tokens = 0 + uni_prompting = kwargs.get("uni_prompting", None) + # print(f"config.model.mmada.llm_vocab_size: {config.model.mmada.llm_vocab_size}, {len(uni_prompting.text_tokenizer)}") + input_ids_minus_lm_vocab_size = input_ids[:, -(num_vq_tokens + 1):-1].clone() + input_ids_minus_lm_vocab_size = torch.where(input_ids_minus_lm_vocab_size == mask_token_id, mask_token_id, input_ids_minus_lm_vocab_size - len(uni_prompting.text_tokenizer) - num_new_special_tokens - 8192) + + # for classifier-free guidance + if uncond_input_ids is not None: + start_gen_idx = (uncond_input_ids[0] == uni_prompting.sptids_dict['<|soa|>'][0].item()).nonzero(as_tuple=True)[0][0].item() + 1 + uncond_prefix = uncond_input_ids[:, :start_gen_idx] + + for step in range(timesteps): + if uncond_input_ids is not None and guidance_scale > 0: + uncond_input_ids = torch.cat( + [uncond_prefix, input_ids[:, start_gen_idx:]], dim=1) + model_input = torch.cat([input_ids, uncond_input_ids]) + all_attention_mask = torch.cat([attention_mask, uncond_attention_mask], dim=0) + attention_bias = (all_attention_mask[:, :, None] & all_attention_mask[:, None, :]).bool().unsqueeze(1) + logits = self(model_input, attention_bias=attention_bias).logits + # print(f"logits.shape: {logits.shape}") + cond_logits, uncond_logits = torch.chunk(logits, 2, dim=0) + # logits = uncond_logits + guidance_scale * (cond_logits - uncond_logits) + # it seems that muse has a different cfg setting + logits = (1 + guidance_scale) * cond_logits - guidance_scale * uncond_logits + logits = logits[:, -(num_vq_tokens + 1):-1, len(uni_prompting.text_tokenizer) + num_new_special_tokens + 8192 : len(uni_prompting.text_tokenizer) + num_new_special_tokens + 8192 + 4096] + else: + attention_bias = (attention_mask[:, :, None] & attention_mask[:, None, :]).bool().unsqueeze(1) + logits = self(input_ids, attention_bias=attention_bias).logits + logits = logits[:, -(num_vq_tokens + 1):-1, len(uni_prompting.text_tokenizer) + num_new_special_tokens + 8192 : len(uni_prompting.text_tokenizer) + num_new_special_tokens + 8192 + 4096] + + # logits: 1, 1024, 8192 + # print(f"logits.shape: {logits.shape}") + probs = logits.softmax(dim=-1) + sampled = probs.reshape(-1, logits.size(-1)) + # print(f"probs: {probs}, probs.shape: {probs.shape}, sampled: {sampled}, sampled.shape: {sampled.shape}") + sampled_ids = torch.multinomial(sampled, 1, generator=generator)[:, 0].view(*logits.shape[:-1]) # 1, 1024 + + unknown_map = input_ids_minus_lm_vocab_size == mask_token_id + # print(f"unknown_map.sum(dim=-1, keepdim=True): {unknown_map.sum(dim=-1, keepdim=True)}") + sampled_ids = torch.where(unknown_map, sampled_ids, input_ids_minus_lm_vocab_size) + # Defines the mask ratio for the next round. The number to mask out is + # determined by mask_ratio * unknown_number_in_the_beginning. + ratio = 1.0 * (step + 1) / timesteps + mask_ratio = noise_schedule(torch.tensor(ratio)) + # Computes the probabilities of each selected tokens. + selected_probs = torch.gather(probs, -1, sampled_ids.long()[..., None]) + selected_probs = selected_probs.squeeze(-1) + + # Ignores the tokens given in the input by overwriting their confidence. + selected_probs = torch.where(unknown_map, selected_probs, torch.finfo(selected_probs.dtype).max) + # Gets mask lens for each sample in the batch according to the mask ratio. + mask_len = (num_vq_tokens * mask_ratio).floor().unsqueeze(0).to(logits.device) + # Keeps at least one of prediction in this round and also masks out at least + # one and for the next iteration + mask_len = torch.max( + torch.tensor([1], device=logits.device), torch.min(unknown_map.sum(dim=-1, keepdim=True) - 1, mask_len) + ) + # print(f"mask_len: {mask_len}, mask_len.shape: {mask_len.shape}") + # Adds noise for randomness + temperature = temperature * (1.0 - ratio) + masking = mask_by_random_topk(mask_len, selected_probs, temperature, generator=generator) + # Masks tokens with lower confidence. + input_ids[:, -(num_vq_tokens + 1):-1] = torch.where(masking, mask_token_id, + sampled_ids + len(uni_prompting.text_tokenizer) + + num_new_special_tokens + 8192) + input_ids_minus_lm_vocab_size = torch.where(masking, mask_token_id, sampled_ids) + + return sampled_ids + + @torch.no_grad() + def i2i_generate( + self, + input_ids: torch.LongTensor = None, + uncond_input_ids: torch.LongTensor = None, + attention_mask=None, + uncond_attention_mask=None, + temperature=1.0, + timesteps=18, # ideal number of steps is 18 in maskgit paper + guidance_scale=0, + noise_schedule=cosine_schedule, + generator: torch.Generator = None, + config=None, + seq_len=1024, + mask_token_id = 126336, + resolution = 512, + codebook_size = 8192, + **kwargs, + ): + """ + Generate 1:1 similar to the original MaskGit repo + https://github.com/google-research/maskgit/blob/main/maskgit/libml/parallel_decode.py#L79 + """ + + # begin with all image token ids masked + # č®”ē®—ęœ‰å¤šå°‘äøŖmask token + mask_count = (input_ids == mask_token_id).sum().item() + num_vq_tokens = seq_len + num_new_special_tokens = 0 + uni_prompting = kwargs.get("uni_prompting", None) + # print(f"config.model.mmada.llm_vocab_size: {config.model.mmada.llm_vocab_size}, {len(uni_prompting.text_tokenizer)}") + input_ids_minus_lm_vocab_size = input_ids[:, -(num_vq_tokens + 1):-1].clone() + input_ids_minus_lm_vocab_size = torch.where(input_ids_minus_lm_vocab_size == mask_token_id, mask_token_id, input_ids_minus_lm_vocab_size - len(uni_prompting.text_tokenizer) - num_new_special_tokens) + + # for classifier-free guidance + if uncond_input_ids is not None: + uncond_prefix = uncond_input_ids[:, :resolution + 1] + + for step in range(timesteps): + if uncond_input_ids is not None and guidance_scale > 0: + uncond_input_ids = torch.cat( + [uncond_prefix, input_ids[:, resolution + 1:]], dim=1) + model_input = torch.cat([input_ids, uncond_input_ids]) + all_attention_mask = torch.cat([attention_mask, uncond_attention_mask], dim=0) + attention_bias = (all_attention_mask[:, :, None] & all_attention_mask[:, None, :]).bool().unsqueeze(1) + logits = self(model_input, attention_bias=attention_bias).logits + # print(f"logits.shape: {logits.shape}") + cond_logits, uncond_logits = torch.chunk(logits, 2, dim=0) + # logits = uncond_logits + guidance_scale * (cond_logits - uncond_logits) + # it seems that muse has a different cfg setting + logits = (1 + guidance_scale) * cond_logits - guidance_scale * uncond_logits + logits = logits[:, -(num_vq_tokens + 1):-1, len(uni_prompting.text_tokenizer) + num_new_special_tokens: len(uni_prompting.text_tokenizer) + num_new_special_tokens + codebook_size] + else: + attention_bias = (attention_mask[:, :, None] & attention_mask[:, None, :]).bool().unsqueeze(1) + logits = self(input_ids, attention_bias=attention_bias).logits + logits = logits[:, -(num_vq_tokens + 1):-1, len(uni_prompting.text_tokenizer) + num_new_special_tokens: len(uni_prompting.text_tokenizer) + num_new_special_tokens + codebook_size] + + # logits: 1, 1024, 8192 + # print(f"logits.shape: {logits.shape}") + probs = logits.softmax(dim=-1) + sampled = probs.reshape(-1, logits.size(-1)) + # print(f"probs: {probs}, probs.shape: {probs.shape}, sampled: {sampled}, sampled.shape: {sampled.shape}") + sampled_ids = torch.multinomial(sampled, 1, generator=generator)[:, 0].view(*logits.shape[:-1]) # 1, 1024 + + unknown_map = input_ids_minus_lm_vocab_size == mask_token_id + # print(f"unknown_map.sum(dim=-1, keepdim=True): {unknown_map.sum(dim=-1, keepdim=True)}") + sampled_ids = torch.where(unknown_map, sampled_ids, input_ids_minus_lm_vocab_size) + # Defines the mask ratio for the next round. The number to mask out is + # determined by mask_ratio * unknown_number_in_the_beginning. + ratio = 1.0 * (step + 1) / timesteps + mask_ratio = noise_schedule(torch.tensor(ratio)) + # Computes the probabilities of each selected tokens. + selected_probs = torch.gather(probs, -1, sampled_ids.long()[..., None]) + selected_probs = selected_probs.squeeze(-1) + + # Ignores the tokens given in the input by overwriting their confidence. + selected_probs = torch.where(unknown_map, selected_probs, torch.finfo(selected_probs.dtype).max) + # Gets mask lens for each sample in the batch according to the mask ratio. + mask_len = (num_vq_tokens * mask_ratio).floor().unsqueeze(0).to(logits.device) + # Keeps at least one of prediction in this round and also masks out at least + # one and for the next iteration + mask_len = torch.max( + torch.tensor([1], device=logits.device), torch.min(unknown_map.sum(dim=-1, keepdim=True) - 1, mask_len) + ) + # print(f"mask_len: {mask_len}, mask_len.shape: {mask_len.shape}") + # Adds noise for randomness + temperature = temperature * (1.0 - ratio) + masking = mask_by_random_topk(mask_len, selected_probs, temperature, generator=generator) + # Masks tokens with lower confidence. + input_ids[:, -(num_vq_tokens + 1):-1] = torch.where(masking, mask_token_id, + sampled_ids + len(uni_prompting.text_tokenizer) + + num_new_special_tokens) + input_ids_minus_lm_vocab_size = torch.where(masking, mask_token_id, sampled_ids) + + return sampled_ids + + # def forward_process( + # self, + # input_ids, + # labels, + # batch_size_t2i=0, + # batch_size_lm=0, + # batch_size_mmu=0, + # batch_size_v2t=0, + # batch_size_s2t=0, + # batch_size_t2s=0, + # max_seq_length=128, + # p_mask_lm=None, + # p_mask_mmu=None, + # p_mask_vid=None, + # p_mask_s2t=None, + # p_mask_t2s=None, + # answer_lengths=None, + # t2i_masks=None, + # answer_lengths_lm=None + # ): + # # attention bias, True for batch_size, 1, seq_len, seq_len + # attention_bias = torch.ones(input_ids.shape[0], 1, input_ids.shape[1], input_ids.shape[1]) + # attention_bias_t2i = (t2i_masks[:, :, None] & t2i_masks[:, None, :]).bool().unsqueeze(1) + # attention_bias[:batch_size_t2i] = attention_bias_t2i + # logits = self(input_ids, attention_bias=attention_bias).logits + # self.output_size = logits.shape[-1] + + # if batch_size_t2i == 0: + # loss_t2i = torch.tensor(0.0, device=input_ids.device) + # else: + # loss_t2i = F.cross_entropy( + # logits[:batch_size_t2i, max_seq_length + 1:].contiguous().view(-1, self.output_size), + # labels[:batch_size_t2i, max_seq_length + 1:].contiguous().view(-1), ignore_index=-100, + # ) + + # masked_indices = input_ids == self.config.mask_token_id + # masked_indices_lm = masked_indices[batch_size_t2i:batch_size_t2i + batch_size_lm] + # masked_indices_mmu = masked_indices[-batch_size_mmu:] + # p_mask_lm = p_mask_lm.to(masked_indices_lm.device) + # p_mask_mmu = p_mask_mmu.to(masked_indices_mmu.device) + # answer_lengths = answer_lengths.to(masked_indices_mmu.device) + # loss_lm = F.cross_entropy( + # logits[batch_size_t2i:batch_size_t2i + batch_size_lm][masked_indices_lm].contiguous().view(-1, self.output_size), + # labels[batch_size_t2i:batch_size_t2i + batch_size_lm][masked_indices_lm].contiguous().view(-1), ignore_index=-100, reduction='none' + # )/p_mask_lm[masked_indices_lm] + + # if answer_lengths_lm is not None: + # loss_lm = torch.sum(loss_lm / answer_lengths_lm[masked_indices_lm]) / (logits[batch_size_t2i:batch_size_t2i + batch_size_lm].shape[0]) + # else: + # loss_lm = loss_lm.sum() / (logits[batch_size_t2i:batch_size_t2i + batch_size_lm].shape[0] * logits[batch_size_t2i:batch_size_t2i + batch_size_lm].shape[1]) + + # loss_mmu = F.cross_entropy( + # logits[-batch_size_mmu:][masked_indices_mmu].contiguous().view(-1, self.output_size), + # labels[-batch_size_mmu:][masked_indices_mmu].contiguous().view(-1), ignore_index=-100, reduction='none' + # )/p_mask_mmu[masked_indices_mmu] + # loss_mmu = torch.sum(loss_mmu/answer_lengths[masked_indices_mmu]) / (logits[-batch_size_mmu:].shape[0]) + + # return logits, loss_t2i, loss_lm, loss_mmu + + # def forward_process( + # self, + # input_ids, + # labels, + # batch_size_t2i=0, + # batch_size_lm=0, + # batch_size_mmu=0, + # batch_size_v2t=0, + # batch_size_s2t=0, + # batch_size_t2s=0, + # max_seq_length=128, + # p_mask_lm=None, + # p_mask_mmu=None, + # p_mask_vid=None, + # p_mask_s2t=None, + # p_mask_t2s=None, + # answer_lengths_lm=None, + # answer_lengths_mmu=None, + # answer_lengths_vid=None, + # answer_lengths_s2t=None, + # answer_lengths_t2s=None, + # t2i_masks=None, + # t2s_vocab_start=None, + # t2s_codebook_size=None, + # t2s_special_token_ids=None + # ): + # # --- 1. Attention Bias Setup (no changes) --- + # attention_bias = torch.ones(input_ids.shape[0], 1, input_ids.shape[1], input_ids.shape[1], device=input_ids.device) + # if batch_size_t2i > 0 and t2i_masks is not None: + # attention_bias_t2i = (t2i_masks[:, :, None] & t2i_masks[:, None, :]).bool().unsqueeze(1) + # attention_bias[:batch_size_t2i] = attention_bias_t2i + + # # --- 2. Model Forward Pass (no changes) --- + # logits = self(input_ids, attention_bias=attention_bias).logits + # self.output_size = logits.shape[-1] + + # # --- 3. Loss Calculation --- + # device = input_ids.device + # zero_loss = torch.tensor(0.0, device=device) + + # # Calculate masked indices for the entire batch + # masked_indices = (input_ids == self.config.mask_token_id) + + # current_idx = 0 + + # # --- T2I Loss --- + # if batch_size_t2i > 0: + # loss_t2i = F.cross_entropy( + # logits[current_idx:current_idx + batch_size_t2i, max_seq_length + 1:].contiguous().view(-1, self.output_size), + # labels[current_idx:current_idx + batch_size_t2i, max_seq_length + 1:].contiguous().view(-1), ignore_index=-100, + # ) + # else: + # loss_t2i = zero_loss + # current_idx += batch_size_t2i + + # # --- LM Loss --- + # if batch_size_lm > 0: + # start, end = current_idx, current_idx + batch_size_lm + # logits_lm, labels_lm = logits[start:end], labels[start:end] + # masked_indices_lm = masked_indices[start:end] + + # loss_lm = F.cross_entropy( + # logits_lm[masked_indices_lm].contiguous().view(-1, self.output_size), + # labels_lm[masked_indices_lm].contiguous().view(-1), ignore_index=-100, reduction='none' + # ) / p_mask_lm.to(device)[masked_indices_lm] + + # if answer_lengths_lm is not None: + # loss_lm = torch.sum(loss_lm / answer_lengths_lm.to(device)[masked_indices_lm]) / logits_lm.shape[0] + # else: + # loss_lm = loss_lm.sum() / logits_lm.shape[0] + # else: + # loss_lm = zero_loss + # current_idx += batch_size_lm + + # # --- MMU Loss --- + # if batch_size_mmu > 0: + # start, end = current_idx, current_idx + batch_size_mmu + # loss_mmu = calculate_mmu_style_loss( + # logits[start:end], labels[start:end], masked_indices[start:end], + # p_mask_mmu, answer_lengths_mmu, self.output_size, device + # ) + # else: + # loss_mmu = zero_loss + # current_idx += batch_size_mmu + + # # --- VID (V2T) Loss --- + # if batch_size_v2t > 0: + # start, end = current_idx, current_idx + batch_size_v2t + # loss_vid = calculate_mmu_style_loss( + # logits[start:end], labels[start:end], masked_indices[start:end], + # p_mask_vid, answer_lengths_vid, self.output_size, device + # ) + # else: + # loss_vid = zero_loss + # current_idx += batch_size_v2t + + # # --- S2T Loss --- + # if batch_size_s2t > 0: + # start, end = current_idx, current_idx + batch_size_s2t + # loss_s2t = calculate_mmu_style_loss( + # logits[start:end], labels[start:end], masked_indices[start:end], + # p_mask_s2t, answer_lengths_s2t, self.output_size, device + # ) + # else: + # loss_s2t = zero_loss + # current_idx += batch_size_s2t + + # # --- T2S Loss --- + # if batch_size_t2s > 0: + # start, end = current_idx, current_idx + batch_size_t2s + # if ( + # t2s_vocab_start is not None + # and t2s_codebook_size is not None + # and t2s_special_token_ids is not None + # ): + # eoa_id = t2s_special_token_ids.get('eoa') + # eos_id = t2s_special_token_ids.get('eos') + # else: + # eoa_id = eos_id = None + + # if eoa_id is not None and eos_id is not None: + # loss_t2s = calculate_t2s_loss( + # logits[start:end], + # labels[start:end], + # masked_indices[start:end], + # p_mask_t2s, + # answer_lengths_t2s, + # t2s_vocab_start, + # t2s_codebook_size, + # eoa_id, + # eos_id, + # device, + # ignore_index=-100, + # ) + # else: + # loss_t2s = calculate_mmu_style_loss( + # logits[start:end], labels[start:end], masked_indices[start:end], + # p_mask_t2s, answer_lengths_t2s, self.output_size, device + # ) + # else: + # loss_t2s = zero_loss + # current_idx += batch_size_t2s + + # return logits, loss_t2i, loss_lm, loss_mmu, loss_vid, loss_s2t, loss_t2s + + def forward_process( + self, + input_ids, + labels, + batch_size_t2i=0, + batch_size_i2i=0, + batch_size_lm=0, + batch_size_mmu=0, + batch_size_v2t=0, + batch_size_v2s=0, + batch_size_s2t=0, + batch_size_s2s=0, + batch_size_t2s=0, + max_seq_length=128, + p_mask_lm=None, + p_mask_mmu=None, + p_mask_vid=None, + p_mask_v2s=None, + p_mask_s2t=None, + p_mask_s2s=None, + p_mask_t2s=None, + answer_lengths_lm=None, + answer_lengths_mmu=None, + answer_lengths_vid=None, + answer_lengths_v2s=None, + answer_lengths_s2t=None, + answer_lengths_s2s=None, + answer_lengths_t2s=None, + t2i_masks=None, + attention_masks_i2i=None, + t2s_vocab_start=None, + t2s_codebook_size=None, + t2s_special_token_ids=None, + text_vocab_size_override=None + ): + # --- 1. Attention Bias Setup (no changes) --- + attention_bias = torch.ones(input_ids.shape[0], 1, input_ids.shape[1], input_ids.shape[1], device=input_ids.device) + if batch_size_t2i > 0 and t2i_masks is not None: + attention_bias_t2i = (t2i_masks[:, :, None] & t2i_masks[:, None, :]).bool().unsqueeze(1) + attention_bias[:batch_size_t2i] = attention_bias_t2i + + if batch_size_i2i > 0 and attention_masks_i2i is not None: + start_i2i = batch_size_t2i + end_i2i = start_i2i + batch_size_i2i + attn_mask = attention_masks_i2i.to(input_ids.device) + if attn_mask.dtype != torch.bool: + attn_mask = attn_mask.bool() + attention_bias_i2i = (attn_mask[:, :, None] & attn_mask[:, None, :]).unsqueeze(1) + attention_bias[start_i2i:end_i2i] = attention_bias_i2i + + # --- 2. Model Forward Pass (no changes) --- + logits = self(input_ids, attention_bias=attention_bias).logits + self.output_size = logits.shape[-1] + + # --- 3. Loss Calculation --- + device = input_ids.device + zero_loss = torch.tensor(0.0, device=device) + + # Calculate masked indices for the entire batch + masked_indices = (input_ids == self.config.mask_token_id) + + text_vocab_size = text_vocab_size_override + image_vocab_size = getattr(self.config, "codebook_size", 0) + image_vocab_start = text_vocab_size + image_vocab_end = min(image_vocab_start + image_vocab_size, logits.shape[-1]) + current_idx = 0 + + # --- T2I Loss --- + if batch_size_t2i > 0: + logits_t2i = logits[current_idx:current_idx + batch_size_t2i, max_seq_length + 1:] + labels_t2i = labels[current_idx:current_idx + batch_size_t2i, max_seq_length + 1:] + if image_vocab_size <= 0: + warnings.warn("t2i encountered non-positive image vocab size; skipping loss.") + loss_t2i = zero_loss + else: + effective_vocab = image_vocab_end - image_vocab_start + if effective_vocab <= 0: + warnings.warn("t2i effective image vocab is invalid; skipping loss.") + loss_t2i = zero_loss + else: + logits_slice = logits_t2i[..., image_vocab_start:image_vocab_end] + labels_relative = torch.full_like(labels_t2i, -100) + valid_mask = (labels_t2i >= image_vocab_start) & (labels_t2i < image_vocab_end) + if not valid_mask.any(): + warnings.warn("t2i labels contain no valid image tokens; skipping loss.") + loss_t2i = zero_loss + else: + labels_relative[valid_mask] = labels_t2i[valid_mask] - image_vocab_start + loss_t2i = F.cross_entropy( + logits_slice.contiguous().view(-1, effective_vocab), + labels_relative.contiguous().view(-1), + ignore_index=-100, + ) + else: + loss_t2i = zero_loss + current_idx += batch_size_t2i + + # --- I2I Loss --- + if batch_size_i2i > 0: + if image_vocab_size <= 0: + warnings.warn("i2i encountered non-positive image vocab size; skipping loss.") + loss_i2i = zero_loss + else: + start, end = current_idx, current_idx + batch_size_i2i + logits_i2i = logits[start:end] + labels_i2i = labels[start:end] + effective_vocab = image_vocab_end - image_vocab_start + if effective_vocab <= 0: + warnings.warn("i2i effective image vocab is invalid; skipping loss.") + loss_i2i = zero_loss + else: + logits_slice = logits_i2i[..., image_vocab_start:image_vocab_end] + labels_relative = torch.full_like(labels_i2i, -100) + image_mask = (labels_i2i >= image_vocab_start) & (labels_i2i < image_vocab_end) + if not image_mask.any(): + warnings.warn("i2i labels contain no valid image tokens; skipping loss.") + loss_i2i = zero_loss + else: + labels_relative[image_mask] = labels_i2i[image_mask] - image_vocab_start + loss_i2i = F.cross_entropy( + logits_slice.contiguous().view(-1, effective_vocab), + labels_relative.contiguous().view(-1), + ignore_index=-100, + ) + else: + loss_i2i = zero_loss + current_idx += batch_size_i2i + + # --- LM Loss --- + if batch_size_lm > 0: + start, end = current_idx, current_idx + batch_size_lm + logits_lm, labels_lm = logits[start:end], labels[start:end] + masked_indices_lm = masked_indices[start:end] + selected_logits_lm = logits_lm[masked_indices_lm] + effective_vocab_lm = selected_logits_lm.shape[-1] + if text_vocab_size and text_vocab_size < self.output_size: + effective_vocab_lm = min(text_vocab_size, selected_logits_lm.shape[-1]) + selected_logits_lm = selected_logits_lm[:, :effective_vocab_lm] + loss_lm = F.cross_entropy( + selected_logits_lm.contiguous().view(-1, effective_vocab_lm), + labels_lm[masked_indices_lm].contiguous().view(-1), ignore_index=-100, reduction='none' + ) / p_mask_lm.to(device)[masked_indices_lm] + + if answer_lengths_lm is not None: + loss_lm = torch.sum(loss_lm / answer_lengths_lm.to(device)[masked_indices_lm]) / logits_lm.shape[0] + else: + loss_lm = loss_lm.sum() / logits_lm.shape[0] + else: + loss_lm = zero_loss + current_idx += batch_size_lm + + # --- MMU Loss --- + if batch_size_mmu > 0: + start, end = current_idx, current_idx + batch_size_mmu + loss_mmu = calculate_mmu_style_loss( + logits[start:end], labels[start:end], masked_indices[start:end], + p_mask_mmu, answer_lengths_mmu, self.output_size, device, + ) + else: + loss_mmu = zero_loss + current_idx += batch_size_mmu + + # --- VID (V2T) Loss --- + if batch_size_v2t > 0: + start, end = current_idx, current_idx + batch_size_v2t + loss_vid = calculate_mmu_style_loss( + logits[start:end], labels[start:end], masked_indices[start:end], + p_mask_vid, answer_lengths_vid, self.output_size, device, + ) + else: + loss_vid = zero_loss + current_idx += batch_size_v2t + + # --- V2S Loss --- + if batch_size_v2s > 0: + start, end = current_idx, current_idx + batch_size_v2s + if ( + t2s_vocab_start is None + or t2s_codebook_size is None + or t2s_special_token_ids is None + ): + warnings.warn("v2s missing t2s vocab configuration; skipping loss.") + loss_v2s = zero_loss + elif answer_lengths_v2s is None or not (answer_lengths_v2s > 0).any(): + warnings.warn("v2s encountered empty answer lengths; skipping loss.") + loss_v2s = zero_loss + else: + eoa_id = t2s_special_token_ids.get('eoa') + eos_id = t2s_special_token_ids.get('eos') + loss_v2s = calculate_t2s_loss( + logits[start:end], + labels[start:end], + masked_indices[start:end], + p_mask_v2s, + answer_lengths_v2s, + t2s_vocab_start, + t2s_codebook_size, + eoa_id, + eos_id, + device, + ignore_index=-100, + ) + else: + loss_v2s = zero_loss + current_idx += batch_size_v2s + + # --- S2T Loss --- + if batch_size_s2t > 0: + start, end = current_idx, current_idx + batch_size_s2t + loss_s2t = calculate_mmu_style_loss( + logits[start:end], labels[start:end], masked_indices[start:end], + p_mask_s2t, answer_lengths_s2t, self.output_size, device, + ) + else: + loss_s2t = zero_loss + current_idx += batch_size_s2t + + # --- S2S Loss --- + if batch_size_s2s > 0: + start, end = current_idx, current_idx + batch_size_s2s + if ( + t2s_vocab_start is None + or t2s_codebook_size is None + or t2s_special_token_ids is None + or p_mask_s2s is None + or answer_lengths_s2s is None + ): + warnings.warn("s2s missing t2s vocab configuration or masks; skipping loss.") + loss_s2s = zero_loss + elif not (answer_lengths_s2s > 0).any(): + warnings.warn("s2s encountered empty answer lengths; skipping loss.") + loss_s2s = zero_loss + else: + eoa_id = t2s_special_token_ids.get('eoa') + eos_id = t2s_special_token_ids.get('eos') + loss_s2s = calculate_t2s_loss( + logits[start:end], + labels[start:end], + masked_indices[start:end], + p_mask_s2s, + answer_lengths_s2s, + t2s_vocab_start, + t2s_codebook_size, + eoa_id, + eos_id, + device, + ignore_index=-100, + ) + else: + loss_s2s = zero_loss + current_idx += batch_size_s2s + + # --- T2S Loss --- + if batch_size_t2s > 0: + start, end = current_idx, current_idx + batch_size_t2s + if ( + t2s_vocab_start is not None + and t2s_codebook_size is not None + and t2s_special_token_ids is not None + ): + eoa_id = t2s_special_token_ids.get('eoa') + eos_id = t2s_special_token_ids.get('eos') + else: + eoa_id = eos_id = None + + if eoa_id is not None and eos_id is not None: + loss_t2s = calculate_t2s_loss( + logits[start:end], + labels[start:end], + masked_indices[start:end], + p_mask_t2s, + answer_lengths_t2s, + t2s_vocab_start, + t2s_codebook_size, + eoa_id, + eos_id, + device, + ignore_index=-100, + ) + else: + loss_t2s = calculate_mmu_style_loss( + logits[start:end], labels[start:end], masked_indices[start:end], + p_mask_t2s, answer_lengths_t2s, self.output_size, device + ) + else: + loss_t2s = zero_loss + current_idx += batch_size_t2s + + return logits, loss_t2i, loss_i2i, loss_lm, loss_mmu, loss_vid, loss_v2s, loss_s2t, loss_s2s, loss_t2s + + def forward_process_with_r2i( + self, + input_ids, + labels, + t2i_masks=None, + max_seq_length=128, + batch_size_t2i=0, + batch_size_lm=0, + batch_size_mmu=0, + batch_size_r2i=0, + p_mask_lm=None, + p_mask_mmu=None, + p_mask_r2i=None, + answer_lengths=None, + answer_lengths_lm=None, + answer_lengths_r2i=None, + ): + # attention bias, True for batch_size, 1, seq_len, seq_len + attention_bias = torch.ones(input_ids.shape[0], 1, input_ids.shape[1], input_ids.shape[1]) + attention_bias_t2i = (t2i_masks[:, :, None] & t2i_masks[:, None, :]).bool().unsqueeze(1) + attention_bias[:batch_size_t2i] = attention_bias_t2i + logits = self(input_ids, attention_bias=attention_bias).logits + # logits = self(input_ids).logits + self.output_size = logits.shape[-1] + + if batch_size_t2i == 0: + loss_t2i = torch.tensor(0.0, device=input_ids.device) + else: + # t2i loss + loss_t2i = F.cross_entropy( + logits[:batch_size_t2i, max_seq_length + 1:].contiguous().view(-1, self.output_size), + labels[:batch_size_t2i, max_seq_length + 1:].contiguous().view(-1), ignore_index=-100, + ) + + # llada loss + + start_lm = batch_size_t2i + end_lm = start_lm + batch_size_lm + start_mmu = end_lm + end_mmu = start_mmu + batch_size_mmu + start_r2i = end_mmu + end_r2i = start_r2i + batch_size_r2i + + masked_indices = input_ids == self.config.mask_token_id + masked_indices_lm = masked_indices[start_lm:end_lm] + masked_indices_mmu = masked_indices[start_mmu:end_mmu] + masked_indices_r2i = masked_indices[start_r2i:end_r2i] + + p_mask_lm = p_mask_lm.to(masked_indices_lm.device) + p_mask_mmu = p_mask_mmu.to(masked_indices_mmu.device) + p_mask_r2i = p_mask_r2i.to(masked_indices_r2i.device) + + answer_lengths = answer_lengths.to(masked_indices_mmu.device) + answer_lengths_lm = answer_lengths_lm.to(masked_indices_lm.device) + answer_lengths_r2i = answer_lengths_r2i.to(masked_indices_r2i.device) + + loss_lm = F.cross_entropy( + logits[start_lm:end_lm][masked_indices_lm].contiguous().view(-1, self.output_size), + labels[start_lm:end_lm][masked_indices_lm].contiguous().view(-1), ignore_index=-100, reduction='none' + )/p_mask_lm[masked_indices_lm] + + if answer_lengths_lm is not None: + loss_lm = torch.sum(loss_lm / answer_lengths_lm[masked_indices_lm]) / (logits[start_lm:end_lm].shape[0]) + else: + loss_lm = loss_lm.sum() / (logits[start_lm:end_lm].shape[0] * logits[start_lm:end_lm].shape[1]) + + loss_mmu = F.cross_entropy( + logits[start_mmu:end_mmu][masked_indices_mmu].contiguous().view(-1, self.output_size), + labels[start_mmu:end_mmu][masked_indices_mmu].contiguous().view(-1), ignore_index=-100, reduction='none' + )/p_mask_mmu[masked_indices_mmu] + loss_mmu = torch.sum(loss_mmu/answer_lengths[masked_indices_mmu]) / (logits[start_mmu:end_mmu].shape[0]) + + loss_r2i = F.cross_entropy( + logits[start_r2i:end_r2i][masked_indices_r2i].contiguous().view(-1, self.output_size), + labels[start_r2i:end_r2i][masked_indices_r2i].contiguous().view(-1), ignore_index=-100, reduction='none' + )/p_mask_r2i[masked_indices_r2i] + loss_r2i = torch.sum(loss_r2i/answer_lengths_r2i[masked_indices_r2i]) / (logits[start_r2i:end_r2i].shape[0]) + + return logits, loss_t2i, loss_lm, loss_mmu, loss_r2i + + def forward_t2i( + self, + input_ids, + labels, + batch_size_t2i=0, + max_seq_length=128, + t2i_masks=None + ): + # attention bias, True for batch_size, 1, seq_len, seq_len + attention_bias = torch.ones(input_ids.shape[0], 1, input_ids.shape[1], input_ids.shape[1]) + attention_bias_t2i = (t2i_masks[:, :, None] & t2i_masks[:, None, :]).bool().unsqueeze(1) + attention_bias[:batch_size_t2i] = attention_bias_t2i + logits = self(input_ids, attention_bias=attention_bias).logits + # logits = self(input_ids).logits + self.output_size = logits.shape[-1] + + # print(f"logits shape: {logits.shape}") B, 359, vocab_size + + loss_t2i = F.cross_entropy( + logits[:batch_size_t2i, max_seq_length + 1:].contiguous().view(-1, self.output_size), + labels[:batch_size_t2i, max_seq_length + 1:].contiguous().view(-1), ignore_index=-100, + ) + + return loss_t2i + + # Temp + def forward_i2i(self, input_ids, attention_mask, labels): + """ + Forward pass for the I2I task. + """ + outputs = self( + input_ids=input_ids, + attention_mask=attention_mask + ) + logits = outputs.logits + + loss = F.cross_entropy( + logits.view(-1, logits.size(-1)), + labels.view(-1), + ignore_index=-100 + ) + + return logits, loss + + # Temp + def forward_s2t( + self, + input_ids, + labels, + batch_size_s2t=0, + max_seq_length=128, + p_mask_s2t=None, + answer_lengths=None, + ): + # attention bias, True for batch_size, 1, seq_len, seq_len + attention_bias = torch.ones(input_ids.shape[0], 1, input_ids.shape[1], input_ids.shape[1], device=input_ids.device) + logits = self(input_ids, attention_bias=attention_bias).logits + self.output_size = logits.shape[-1] + + masked_indices = input_ids == self.config.mask_token_id + masked_indices_s2t = masked_indices[-batch_size_s2t:] + p_mask_s2t = p_mask_s2t.to(masked_indices_s2t.device) + answer_lengths = answer_lengths.to(masked_indices_s2t.device) + + loss_s2t = F.cross_entropy( + logits[-batch_size_s2t:][masked_indices_s2t].contiguous().view(-1, self.output_size), + labels[-batch_size_s2t:][masked_indices_s2t].contiguous().view(-1), ignore_index=-100, reduction='none' + )/p_mask_s2t[masked_indices_s2t] + loss_s2t = torch.sum(loss_s2t/answer_lengths[masked_indices_s2t]) / (logits[-batch_size_s2t:].shape[0]) + + return logits, loss_s2t + + def forward_t2s( + self, + input_ids, + labels, + batch_size_t2s=0, + max_seq_length=128, + p_mask_t2s=None, + answer_lengths=None, + ): + """ + Forward pass for text-to-speech (T2S) diffusion LM training. + + Args: + input_ids: (B, L) Input token IDs (text + [MASK]*len(speech)). + labels: (B, L) Target speech codebook token IDs. + batch_size_t2s: Batch size for t2s task (for multitask batches). + max_seq_length: Prompt(text) źøøģ“ + p_mask_t2s: (B, L) Mask probability per position (optional). + answer_lengths: (B,) 각 row별 target length (optional). + Returns: + logits, loss_t2s + """ + attention_bias = torch.ones(input_ids.shape[0], 1, input_ids.shape[1], input_ids.shape[1], device=input_ids.device) + logits = self(input_ids, attention_bias=attention_bias).logits + self.output_size = logits.shape[-1] + + masked_indices = input_ids == self.config.mask_token_id + masked_indices_t2s = masked_indices[-batch_size_t2s:] + p_mask_t2s = p_mask_t2s.to(masked_indices_t2s.device) + answer_lengths = answer_lengths.to(masked_indices_t2s.device) + + loss_t2s = F.cross_entropy( + logits[-batch_size_t2s:][masked_indices_t2s].contiguous().view(-1, self.output_size), + labels[-batch_size_t2s:][masked_indices_t2s].contiguous().view(-1), + ignore_index=-100, reduction='none' + ) / p_mask_t2s[masked_indices_t2s] + loss_t2s = torch.sum(loss_t2s / answer_lengths[masked_indices_t2s]) / logits[-batch_size_t2s:].shape[0] + + return logits, loss_t2s + + def forward_v2t( + self, + input_ids, + labels, + batch_size_v2t=0, + max_seq_length=128, + p_mask_v2t=None, + answer_lengths=None, + ): + """ + video-to-text (V2T) diffusion LM training. + """ + attention_bias = torch.ones(input_ids.shape[0], 1, input_ids.shape[1], input_ids.shape[1], device=input_ids.device) + logits = self(input_ids, attention_bias=attention_bias).logits + self.output_size = logits.shape[-1] + + masked_indices = input_ids == self.config.mask_token_id + masked_indices_v2t = masked_indices[:batch_size_v2t] + p_mask_v2t = p_mask_v2t.to(masked_indices_v2t.device) + answer_lengths = answer_lengths.to(masked_indices_v2t.device) + + loss_v2t = F.cross_entropy( + logits[:batch_size_v2t][masked_indices_v2t].contiguous().view(-1, self.output_size), + labels[:batch_size_v2t][masked_indices_v2t].contiguous().view(-1), + ignore_index=-100, + reduction='none' + ) / p_mask_v2t[masked_indices_v2t] + loss_v2t = torch.sum(loss_v2t / answer_lengths[masked_indices_v2t]) / (logits[:batch_size_v2t].shape[0]) + return logits, loss_v2t + + def forward_v2t_encoder( + self, + input_ids, + labels, + batch_size_v2t=0, + max_seq_length=128, + p_mask_v2t=None, + answer_lengths=None, + ): + """ + video-to-text (V2T) diffusion LM training. + """ + attention_bias = torch.ones(input_ids.shape[0], 1, input_ids.shape[1], input_ids.shape[1], device=input_ids.device) + input_embeddings = super().model.transformer.wte(input_ids) + + + logits = self(input_ids, attention_bias=attention_bias).logits + self.output_size = logits.shape[-1] + + masked_indices = input_ids == self.config.mask_token_id + masked_indices_v2t = masked_indices[:batch_size_v2t] + p_mask_v2t = p_mask_v2t.to(masked_indices_v2t.device) + answer_lengths = answer_lengths.to(masked_indices_v2t.device) + + loss_v2t = F.cross_entropy( + logits[:batch_size_v2t][masked_indices_v2t].contiguous().view(-1, self.output_size), + labels[:batch_size_v2t][masked_indices_v2t].contiguous().view(-1), + ignore_index=-100, + reduction='none' + ) / p_mask_v2t[masked_indices_v2t] + loss_v2t = torch.sum(loss_v2t / answer_lengths[masked_indices_v2t]) / (logits[:batch_size_v2t].shape[0]) + return logits, loss_v2t + + def forward_v2s( + self, + input_ids, + labels, + batch_size_v2s=0, + max_seq_length: int = 128, + p_mask_v2s=None, + answer_lengths=None, + t2s_vocab_start: Optional[int] = None, + t2s_codebook_size: Optional[int] = None, + t2s_special_token_ids: Optional[Dict[str, int]] = None, + ): + """ + # video-to-speech (V2S) diffusion LM training. + """ + attention_bias = torch.ones(input_ids.shape[0], 1, input_ids.shape[1], input_ids.shape[1], device=input_ids.device) + logits = self(input_ids, attention_bias=attention_bias).logits + self.output_size = logits.shape[-1] + + masked_indices = input_ids == self.config.mask_token_id + masked_indices_v2s = masked_indices[:batch_size_v2s] + if batch_size_v2s == 0: + return logits, torch.tensor(0.0, device=input_ids.device) + + p_mask_v2s = p_mask_v2s.to(masked_indices_v2s.device) + answer_lengths = answer_lengths.to(masked_indices_v2s.device) + + if ( + t2s_vocab_start is not None + and t2s_codebook_size is not None + and t2s_special_token_ids is not None + ): + eoa_id = t2s_special_token_ids.get('eoa') + eos_id = t2s_special_token_ids.get('eos') + else: + eoa_id = eos_id = None + + loss_v2s = calculate_t2s_loss( + logits[:batch_size_v2s], + labels[:batch_size_v2s], + masked_indices_v2s, + p_mask_v2s, + answer_lengths, + t2s_vocab_start, + t2s_codebook_size, + eoa_id, + eos_id, + input_ids.device, + ignore_index=-100, + ) + return logits, loss_v2s + + + # def forward_i2i(self, input_ids, attention_mask, labels, max_prompt_length): + # """ + # Forward pass for the I2I task. + # """ + # outputs = self( + # input_ids=input_ids, + # attention_mask=attention_mask + # ) + # logits = outputs.logits + + # logits_for_loss = logits[:, max_prompt_length:].contiguous() + # labels_for_loss = labels[:, max_prompt_length:].contiguous() + + # loss = F.cross_entropy( + # logits_for_loss.view(-1, logits_for_loss.size(-1)), + # labels_for_loss.view(-1), + # ignore_index=-100 + # ) + + # return logits, loss + + @torch.no_grad() + def mmu_generate(self, idx=None, input_embeddings=None, max_new_tokens=128, steps=128,block_length=128, temperature=0.0, top_k=None, eot_token=None, cfg_scale=0.0, remasking='low_confidence', mask_id=126336, attention_mask=None): + """ + Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete + the sequence max_new_tokens times, feeding the predictions back into the model each time. + Most likely you'll want to make sure to be in model.eval() mode of operation for this. + """ + + if attention_mask is not None and 0.0 in attention_mask: + attention_bias = (attention_mask[:, :, None] & attention_mask[:, None, :]).bool().unsqueeze(1) + # print(f"attention_bias: {attention_bias}") + else: + attention_bias = None + try: + device = idx.device + except: + device = input_embeddings.device + + result = [] + batch_size = idx.shape[0] + x = torch.full((batch_size, idx.shape[1] + max_new_tokens), mask_id, dtype=torch.long).to(self.device) + x[:, :idx.shape[1]] = idx.clone() + prompt_index = (x != mask_id) + + + assert max_new_tokens % block_length == 0 + num_blocks = max_new_tokens // block_length + + assert steps % num_blocks == 0 + steps = steps // num_blocks + + # print(f"num_blocks: {num_blocks}, steps: {steps}") + # num_transfer_tokens = get_num_transfer_tokens(prompt_index, steps) + for num_block in range(num_blocks): + block_mask_index = (x[:, idx.shape[1] + num_block * block_length: idx.shape[1] + (num_block + 1) * block_length:] == mask_id) + num_transfer_tokens = get_num_transfer_tokens(block_mask_index, steps) + # num_transfer_tokens = get_num_transfer_tokens(prompt_index, steps) + # print(f"num_transfer_tokens: {num_transfer_tokens}, num_transfer_tokens.shape: {num_transfer_tokens.shape}") + for i in range(steps): + mask_index = (x == mask_id) + if cfg_scale > 0.0: + un_x = x.clone() + un_x[prompt_index] = mask_id + x_ = torch.cat([x, un_x], dim=0) + logits = self(x_).logits + logits, un_logits = torch.chunk(logits, 2, dim=0) + logits = un_logits + (cfg_scale + 1) * (logits - un_logits) + else: + logits = self(x, attention_bias=attention_bias).logits + + logits_with_noise = add_gumbel_noise(logits, temperature=temperature) + x0 = torch.argmax(logits_with_noise, dim=-1) # b, l + if remasking == 'low_confidence': + p = F.softmax(logits.to(torch.float64), dim=-1) + x0_p = torch.squeeze( + torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1) # b, l + elif remasking == 'random': + x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device) + else: + raise NotImplementedError(remasking) + + x0_p[:, idx.shape[1] + (num_block + 1) * block_length:] = -np.inf + + x0 = torch.where(mask_index, x0, x) + confidence = torch.where(mask_index, x0_p, -np.inf) + + transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device) + for j in range(confidence.shape[0]): + _, select_index = torch.topk(confidence[j], k=num_transfer_tokens[j, i]) + transfer_index[j, select_index] = True + x[transfer_index] = x0[transfer_index] + + + # logits = logits[:, -1, :] / temperature + # # optionally crop the logits to only the top k options + # if top_k is not None: + # v, _ = torch.topk(logits, min(top_k, logits.size(-1))) + # logits[logits < v[:, [-1]]] = -float('Inf') + # # apply softmax to convert logits to (normalized) probabilities + # probs = F.softmax(logits, dim=-1) + # # sample from the distribution + # idx_next = torch.multinomial(probs, num_samples=1) + # result.append(idx_next[0][0]) + # # append sampled index to the running sequence and continue + # if self.config.w_clip_vit: + # idx_next_embeddings = self.mmada.model.embed_tokens(idx_next) + # input_embeddings = torch.cat([input_embeddings, idx_next_embeddings], dim=1) + # else: + # idx = torch.cat((idx, idx_next), dim=1) + + # if eot_token is not None and idx_next.cpu() == eot_token: + # break + + return x + + + @torch.no_grad() + def s2t_generate(self, idx=None, input_embeddings=None, max_new_tokens=128, steps=128,block_length=128, temperature=0.0, top_k=None, eot_token=None, cfg_scale=0.0, remasking='low_confidence', mask_id=126336, attention_mask=None): + """ + Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete + the sequence max_new_tokens times, feeding the predictions back into the model each time. + Most likely you'll want to make sure to be in model.eval() mode of operation for this. + """ + + if attention_mask is not None and 0.0 in attention_mask: + attention_bias = (attention_mask[:, :, None] & attention_mask[:, None, :]).bool().unsqueeze(1) + # print(f"attention_bias: {attention_bias}") + else: + attention_bias = None + try: + device = idx.device + except: + device = input_embeddings.device + + result = [] + batch_size = idx.shape[0] + x = torch.full((batch_size, idx.shape[1] + max_new_tokens), mask_id, dtype=torch.long).to(self.device) + x[:, :idx.shape[1]] = idx.clone() + prompt_index = (x != mask_id) + + + assert max_new_tokens % block_length == 0 + num_blocks = max_new_tokens // block_length + + assert steps % num_blocks == 0 + steps = steps // num_blocks + + # print(f"num_blocks: {num_blocks}, steps: {steps}") + # num_transfer_tokens = get_num_transfer_tokens(prompt_index, steps) + for num_block in range(num_blocks): + block_mask_index = (x[:, idx.shape[1] + num_block * block_length: idx.shape[1] + (num_block + 1) * block_length:] == mask_id) + num_transfer_tokens = get_num_transfer_tokens(block_mask_index, steps) + # num_transfer_tokens = get_num_transfer_tokens(prompt_index, steps) + # print(f"num_transfer_tokens: {num_transfer_tokens}, num_transfer_tokens.shape: {num_transfer_tokens.shape}") + for i in range(steps): + mask_index = (x == mask_id) + if cfg_scale > 0.0: + un_x = x.clone() + un_x[prompt_index] = mask_id + x_ = torch.cat([x, un_x], dim=0) + logits = self(x_).logits + logits, un_logits = torch.chunk(logits, 2, dim=0) + logits = un_logits + (cfg_scale + 1) * (logits - un_logits) + else: + logits = self(x, attention_bias=attention_bias).logits + + logits_with_noise = add_gumbel_noise(logits, temperature=temperature) + x0 = torch.argmax(logits_with_noise, dim=-1) # b, l + if remasking == 'low_confidence': + p = F.softmax(logits.to(torch.float64), dim=-1) + x0_p = torch.squeeze( + torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1) # b, l + elif remasking == 'random': + x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device) + else: + raise NotImplementedError(remasking) + + x0_p[:, idx.shape[1] + (num_block + 1) * block_length:] = -np.inf + + x0 = torch.where(mask_index, x0, x) + confidence = torch.where(mask_index, x0_p, -np.inf) + + transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device) + for j in range(confidence.shape[0]): + _, select_index = torch.topk(confidence[j], k=num_transfer_tokens[j, i]) + transfer_index[j, select_index] = True + x[transfer_index] = x0[transfer_index] + + + # logits = logits[:, -1, :] / temperature + # # optionally crop the logits to only the top k options + # if top_k is not None: + # v, _ = torch.topk(logits, min(top_k, logits.size(-1))) + # logits[logits < v[:, [-1]]] = -float('Inf') + # # apply softmax to convert logits to (normalized) probabilities + # probs = F.softmax(logits, dim=-1) + # # sample from the distribution + # idx_next = torch.multinomial(probs, num_samples=1) + # result.append(idx_next[0][0]) + # # append sampled index to the running sequence and continue + # if self.config.w_clip_vit: + # idx_next_embeddings = self.mmada.model.embed_tokens(idx_next) + # input_embeddings = torch.cat([input_embeddings, idx_next_embeddings], dim=1) + # else: + # idx = torch.cat((idx, idx_next), dim=1) + + # if eot_token is not None and idx_next.cpu() == eot_token: + # break + + return x + + @torch.no_grad() + def mmu_generate_fast(self, idx=None, input_embeddings=None, max_new_tokens=128, steps=128,block_length=128, temperature=0.0, top_k=None, eot_token=None, cfg_scale=0.0, remasking='low_confidence', mask_id=126336, attention_mask=None): + """ + Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete + the sequence max_new_tokens times, feeding the predictions back into the model each time. + Most likely you'll want to make sure to be in model.eval() mode of operation for this. + """ + + if attention_mask is not None and 0.0 in attention_mask: + attention_bias = (attention_mask[:, :, None] & attention_mask[:, None, :]).bool().unsqueeze(1) + # print(f"attention_bias: {attention_bias}") + else: + attention_bias = None + try: + device = idx.device + except: + device = input_embeddings.device + + result = [] + batch_size = idx.shape[0] + x = torch.full((batch_size, idx.shape[1] + max_new_tokens), mask_id, dtype=torch.long).to(self.device) + x[:, :idx.shape[1]] = idx.clone() + prompt_index = (x != mask_id) + + + assert max_new_tokens % block_length == 0 + num_blocks = max_new_tokens // block_length + + assert steps % num_blocks == 0 + steps = steps // num_blocks + + for num_block in range(num_blocks): + block_mask_index = (x[:, idx.shape[1] + num_block * block_length: idx.shape[1] + (num_block + 1) * block_length:] == mask_id) + num_transfer_tokens = get_num_transfer_tokens(block_mask_index, steps) + for i in range(steps): + mask_index = (x == mask_id) + if cfg_scale > 0.0: + un_x = x.clone() + un_x[prompt_index] = mask_id + x_ = torch.cat([x, un_x], dim=0) + logits = self(x_).logits + logits, un_logits = torch.chunk(logits, 2, dim=0) + logits = un_logits + (cfg_scale + 1) * (logits - un_logits) + else: + logits = self(x, attention_bias=attention_bias).logits + + logits_with_noise = add_gumbel_noise(logits, temperature=temperature) + x0 = torch.argmax(logits_with_noise, dim=-1) # b, l + if remasking == 'low_confidence': + p = F.softmax(logits.to(torch.float64), dim=-1) + x0_p = torch.squeeze( + torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1) # b, l + elif remasking == 'random': + x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device) + else: + raise NotImplementedError(remasking) + + x0_p[:, idx.shape[1] + (num_block + 1) * block_length:] = -np.inf + + x0 = torch.where(mask_index, x0, x) + confidence = torch.where(mask_index, x0_p, -np.inf) + + transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device) + for j in range(confidence.shape[0]): + _, select_index = torch.topk(confidence[j], k=num_transfer_tokens[j, i]) + transfer_index[j, select_index] = True + x[transfer_index] = x0[transfer_index] + if eot_token is not None: + last_token_index_in_current_block = idx.shape[1] + (num_block + 1) * block_length - 1 + if last_token_index_in_current_block < x.shape[1]: + tokens_at_block_end = x[:, last_token_index_in_current_block] + if torch.all(tokens_at_block_end == eot_token): + break + return x + + @torch.no_grad() + def t2i_generate_decoding_stepwise( + self, + input_ids: torch.LongTensor = None, + uncond_input_ids: torch.LongTensor = None, + attention_mask=None, + uncond_attention_mask=None, + temperature=1.0, + timesteps=18, # ideal number of steps is 18 in maskgit paper + guidance_scale=0, + noise_schedule=cosine_schedule, + generator: torch.Generator = None, + config=None, + seq_len=1024, + mask_token_id = 126336, + resolution = 512, + codebook_size = 8192, + vq_model = None, + **kwargs, + ): + """ + Generate 1:1 similar to the original MaskGit repo + https://github.com/google-research/maskgit/blob/main/maskgit/libml/parallel_decode.py#L79 + """ + + # begin with all image token ids masked + # č®”ē®—ęœ‰å¤šå°‘äøŖmask token + mask_count = (input_ids == mask_token_id).sum().item() + num_vq_tokens = seq_len + num_new_special_tokens = 0 + uni_prompting = kwargs.get("uni_prompting", None) + # print(f"config.model.mmada.llm_vocab_size: {config.model.mmada.llm_vocab_size}, {len(uni_prompting.text_tokenizer)}") + input_ids_minus_lm_vocab_size = input_ids[:, -(num_vq_tokens + 1):-1].clone() + input_ids_minus_lm_vocab_size = torch.where(input_ids_minus_lm_vocab_size == mask_token_id, mask_token_id, input_ids_minus_lm_vocab_size - len(uni_prompting.text_tokenizer) - num_new_special_tokens) + + # for classifier-free guidance + if uncond_input_ids is not None: + uncond_prefix = uncond_input_ids[:, :resolution + 1] + + for step in range(timesteps): + if uncond_input_ids is not None and guidance_scale > 0: + uncond_input_ids = torch.cat( + [uncond_prefix, input_ids[:, resolution + 1:]], dim=1) + model_input = torch.cat([input_ids, uncond_input_ids]) + attention_mask = torch.cat([attention_mask, uncond_attention_mask], dim=0) + attention_bias = (attention_mask[:, :, None] & attention_mask[:, None, :]).bool().unsqueeze(1) + logits = self(model_input, attention_bias=attention_bias).logits + # print(f"logits.shape: {logits.shape}") + cond_logits, uncond_logits = torch.chunk(logits, 2, dim=0) + # logits = uncond_logits + guidance_scale * (cond_logits - uncond_logits) + # it seems that muse has a different cfg setting + logits = (1 + guidance_scale) * cond_logits - guidance_scale * uncond_logits + logits = logits[:, -(num_vq_tokens + 1):-1, len(uni_prompting.text_tokenizer) + num_new_special_tokens: len(uni_prompting.text_tokenizer) + num_new_special_tokens + codebook_size] + else: + attention_bias = (attention_mask[:, :, None] & attention_mask[:, None, :]).bool().unsqueeze(1) + logits = self(input_ids, attention_bias=attention_bias).logits + logits = logits[:, -(num_vq_tokens + 1):-1, len(uni_prompting.text_tokenizer) + num_new_special_tokens: len(uni_prompting.text_tokenizer) + num_new_special_tokens + codebook_size] + + # logits: 1, 1024, 8192 + # print(f"logits.shape: {logits.shape}") + probs = logits.softmax(dim=-1) + sampled = probs.reshape(-1, logits.size(-1)) + # print(f"probs: {probs}, probs.shape: {probs.shape}, sampled: {sampled}, sampled.shape: {sampled.shape}") + sampled_ids = torch.multinomial(sampled, 1, generator=generator)[:, 0].view(*logits.shape[:-1]) # 1, 1024 + + unknown_map = input_ids_minus_lm_vocab_size == mask_token_id + # print(f"unknown_map.sum(dim=-1, keepdim=True): {unknown_map.sum(dim=-1, keepdim=True)}") + sampled_ids = torch.where(unknown_map, sampled_ids, input_ids_minus_lm_vocab_size) + # Defines the mask ratio for the next round. The number to mask out is + current_image_vq_indices = sampled_ids.clone() + # print(f"current_image_vq_indices: {current_image_vq_indices}") + current_image_vq_indices = torch.clamp(current_image_vq_indices, 0, 8192 - 1) + current_image = vq_model.decode_code(current_image_vq_indices) + images = torch.clamp((current_image + 1.0) / 2.0, min=0.0, max=1.0) + images *= 255.0 + images = images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8) + pil_images = Image.fromarray(images[0]) + yield pil_images, f"Step {step + 1}/{timesteps}" + # determined by mask_ratio * unknown_number_in_the_beginning. + ratio = 1.0 * (step + 1) / timesteps + mask_ratio = noise_schedule(torch.tensor(ratio)) + # Computes the probabilities of each selected tokens. + selected_probs = torch.gather(probs, -1, sampled_ids.long()[..., None]) + selected_probs = selected_probs.squeeze(-1) + + # Ignores the tokens given in the input by overwriting their confidence. + selected_probs = torch.where(unknown_map, selected_probs, torch.finfo(selected_probs.dtype).max) + # Gets mask lens for each sample in the batch according to the mask ratio. + mask_len = (num_vq_tokens * mask_ratio).floor().unsqueeze(0).to(logits.device) + # Keeps at least one of prediction in this round and also masks out at least + # one and for the next iteration + mask_len = torch.max( + torch.tensor([1], device=logits.device), torch.min(unknown_map.sum(dim=-1, keepdim=True) - 1, mask_len) + ) + # print(f"mask_len: {mask_len}, mask_len.shape: {mask_len.shape}") + # Adds noise for randomness + temperature = temperature * (1.0 - ratio) + masking = mask_by_random_topk(mask_len, selected_probs, temperature, generator=generator) + # Masks tokens with lower confidence. + input_ids[:, -(num_vq_tokens + 1):-1] = torch.where(masking, mask_token_id, + sampled_ids + len(uni_prompting.text_tokenizer) + + num_new_special_tokens) + input_ids_minus_lm_vocab_size = torch.where(masking, mask_token_id, sampled_ids) + + + return sampled_ids + + +AutoConfig.register("omada", OMadaConfig) +AutoModelForCausalLM.register(OMadaConfig, OMadaModelLM) +AutoModel.register(OMadaConfig, OMadaModelLM) diff --git a/MMaDA/models/modeling_omada.py.bak b/MMaDA/models/modeling_omada.py.bak new file mode 100644 index 0000000000000000000000000000000000000000..7d2945b7ac8b61b134993d9a30fe39856cbeaf94 --- /dev/null +++ b/MMaDA/models/modeling_omada.py.bak @@ -0,0 +1,1930 @@ +from __future__ import annotations + +import logging +import math +import sys +from abc import abstractmethod +from collections import defaultdict +from functools import partial +from typing import ( + Callable, + Dict, + Iterable, + List, + NamedTuple, + Optional, + Sequence, + Set, + Tuple, + cast, +) +from dataclasses import fields +from typing import List, Optional, Tuple, Union +import numpy as np +import torch +import torch.backends.cuda +import torch.nn as nn +import torch.nn.functional as F +from torch import einsum +from transformers import PreTrainedModel +from transformers.modeling_outputs import CausalLMOutputWithPast +from transformers.models.auto import AutoModel, AutoConfig, AutoModelForCausalLM +from transformers.cache_utils import Cache +from PIL import Image +from .configuration_llada import ( + LLaDAConfig, + StrEnum, + InitFnType, + ActivationType, + BlockType, + LayerNormType, + ModelConfig, + ActivationCheckpointingStrategy, +) + +from .modeling_llada import LLaDAModelLM +from .modeling_video_encoder import VideoEncoder +from .sampling import cosine_schedule, mask_by_random_topk +from transformers import PretrainedConfig + +def calculate_mmu_style_loss(logits_batch, labels_batch, masked_indices_batch, p_mask, answer_lengths, output_size, device): + if logits_batch.shape[0] == 0: + return logits_batch.new_zeros(()) + + loss = F.cross_entropy( + logits_batch[masked_indices_batch].contiguous().view(-1, output_size), + labels_batch[masked_indices_batch].contiguous().view(-1), ignore_index=-100, reduction='none' + ) / p_mask.to(device)[masked_indices_batch] + + loss = torch.sum(loss / answer_lengths.to(device)[masked_indices_batch]) / logits_batch.shape[0] + return loss + + +def calculate_t2s_loss( + logits_batch, + labels_batch, + masked_indices_batch, + p_mask, + answer_lengths, + vocab_start, + codebook_size, + eoa_token_id, + eos_token_id, + device, + ignore_index=-100, +): + if logits_batch.shape[0] == 0: + return logits_batch.new_zeros(()) + + selected_logits = logits_batch[masked_indices_batch] + selected_labels = labels_batch[masked_indices_batch] + + if selected_logits.shape[0] == 0: + return logits_batch.new_zeros(()) + + speech_logits = selected_logits[:, vocab_start : vocab_start + codebook_size] + eoa_logits = selected_logits[:, eoa_token_id : eoa_token_id + 1] + eos_logits = selected_logits[:, eos_token_id : eos_token_id + 1] + combined_logits = torch.cat([speech_logits, eoa_logits, eos_logits], dim=-1) + + relative_labels = torch.full_like(selected_labels, ignore_index) + audio_mask = (selected_labels >= vocab_start) & (selected_labels < vocab_start + codebook_size) + relative_labels[audio_mask] = selected_labels[audio_mask] - vocab_start + relative_labels[selected_labels == eoa_token_id] = codebook_size + relative_labels[selected_labels == eos_token_id] = codebook_size + 1 + + loss = F.cross_entropy( + combined_logits, + relative_labels, + ignore_index=ignore_index, + reduction='none' + ) / p_mask.to(device)[masked_indices_batch] + + loss = torch.sum(loss / answer_lengths.to(device)[masked_indices_batch]) / logits_batch.shape[0] + return loss + +def add_gumbel_noise(logits, temperature): + ''' + The Gumbel max is a method for sampling categorical distributions. + According to arXiv:2409.02908, for MDM, low-precision Gumbel Max improves perplexity score but reduces generation quality. + Thus, we use float64. + ''' + if temperature == 0: + return logits + logits = logits.to(torch.float64) + noise = torch.rand_like(logits, dtype=torch.float64) + gumbel_noise = (- torch.log(noise)) ** temperature + return logits.exp() / gumbel_noise + + +def get_num_transfer_tokens(mask_index, steps): + ''' + In the reverse process, the interval [0, 1] is uniformly discretized into steps intervals. + Furthermore, because LLaDA employs a linear noise schedule (as defined in Eq. (8)), + the expected number of tokens transitioned at each step should be consistent. + + This function is designed to precompute the number of tokens that need to be transitioned at each step. + ''' + mask_num = mask_index.sum(dim=1, keepdim=True) + + base = mask_num // steps + remainder = mask_num % steps + + num_transfer_tokens = torch.zeros(mask_num.size(0), steps, device=mask_index.device, dtype=torch.int64) + base + + for i in range(mask_num.size(0)): + num_transfer_tokens[i, :remainder[i]] += 1 + + return num_transfer_tokens + +class OMadaConfig(PretrainedConfig): + model_type = "omada" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + allowed_keys = [ + "vocab_size", + "llm_vocab_size", + "llm_model_path", + "codebook_size", + "num_vq_tokens", + "num_new_special_tokens", + "gradient_checkpointing", + "new_vocab_size", + ] + + for key in allowed_keys: + if key in kwargs: + setattr(self, key, kwargs[key]) + + + +class OMadaModelLM(LLaDAModelLM): + config_class = OMadaConfig + base_model_prefix = "model" + def __init__(self, config: OMadaConfig, *args, **kwargs): + print(f"Initializing OMadaModelLM with config: {config}") + super().__init__(config, *args, **kwargs) + + # # resize token embeddings + # print(f"Resizing token embeddings to {config.new_vocab_size}") + # self.resize_token_embeddings(config.new_vocab_size) + + @torch.no_grad() + def t2i_generate( + self, + input_ids: torch.LongTensor = None, + uncond_input_ids: torch.LongTensor = None, + attention_mask=None, + uncond_attention_mask=None, + temperature=1.0, + timesteps=18, # ideal number of steps is 18 in maskgit paper + guidance_scale=0, + noise_schedule=cosine_schedule, + generator: torch.Generator = None, + config=None, + seq_len=1024, + mask_token_id = 126336, + resolution = 512, + codebook_size = 8192, + **kwargs, + ): + """ + Generate 1:1 similar to the original MaskGit repo + https://github.com/google-research/maskgit/blob/main/maskgit/libml/parallel_decode.py#L79 + """ + + # begin with all image token ids masked + # č®”ē®—ęœ‰å¤šå°‘äøŖmask token + mask_count = (input_ids == mask_token_id).sum().item() + num_vq_tokens = seq_len + num_new_special_tokens = 0 + uni_prompting = kwargs.get("uni_prompting", None) + # print(f"config.model.mmada.llm_vocab_size: {config.model.mmada.llm_vocab_size}, {len(uni_prompting.text_tokenizer)}") + input_ids_minus_lm_vocab_size = input_ids[:, -(num_vq_tokens + 1):-1].clone() + input_ids_minus_lm_vocab_size = torch.where(input_ids_minus_lm_vocab_size == mask_token_id, mask_token_id, input_ids_minus_lm_vocab_size - len(uni_prompting.text_tokenizer) - num_new_special_tokens) + + # for classifier-free guidance + if uncond_input_ids is not None: + uncond_prefix = uncond_input_ids[:, :resolution + 1] + + for step in range(timesteps): + if uncond_input_ids is not None and guidance_scale > 0: + uncond_input_ids = torch.cat( + [uncond_prefix, input_ids[:, resolution + 1:]], dim=1) + model_input = torch.cat([input_ids, uncond_input_ids]) + all_attention_mask = torch.cat([attention_mask, uncond_attention_mask], dim=0) + attention_bias = (all_attention_mask[:, :, None] & all_attention_mask[:, None, :]).bool().unsqueeze(1) + logits = self(model_input, attention_bias=attention_bias).logits + # print(f"logits.shape: {logits.shape}") + cond_logits, uncond_logits = torch.chunk(logits, 2, dim=0) + # logits = uncond_logits + guidance_scale * (cond_logits - uncond_logits) + # it seems that muse has a different cfg setting + logits = (1 + guidance_scale) * cond_logits - guidance_scale * uncond_logits + logits = logits[:, -(num_vq_tokens + 1):-1, len(uni_prompting.text_tokenizer) + num_new_special_tokens: len(uni_prompting.text_tokenizer) + num_new_special_tokens + codebook_size] + else: + attention_bias = (attention_mask[:, :, None] & attention_mask[:, None, :]).bool().unsqueeze(1) + logits = self(input_ids, attention_bias=attention_bias).logits + logits = logits[:, -(num_vq_tokens + 1):-1, len(uni_prompting.text_tokenizer) + num_new_special_tokens: len(uni_prompting.text_tokenizer) + num_new_special_tokens + codebook_size] + + # logits: 1, 1024, 8192 + # print(f"logits.shape: {logits.shape}") + probs = logits.softmax(dim=-1) + sampled = probs.reshape(-1, logits.size(-1)) + # print(f"probs: {probs}, probs.shape: {probs.shape}, sampled: {sampled}, sampled.shape: {sampled.shape}") + sampled_ids = torch.multinomial(sampled, 1, generator=generator)[:, 0].view(*logits.shape[:-1]) # 1, 1024 + + unknown_map = input_ids_minus_lm_vocab_size == mask_token_id + # print(f"unknown_map.sum(dim=-1, keepdim=True): {unknown_map.sum(dim=-1, keepdim=True)}") + sampled_ids = torch.where(unknown_map, sampled_ids, input_ids_minus_lm_vocab_size) + # Defines the mask ratio for the next round. The number to mask out is + # determined by mask_ratio * unknown_number_in_the_beginning. + ratio = 1.0 * (step + 1) / timesteps + mask_ratio = noise_schedule(torch.tensor(ratio)) + # Computes the probabilities of each selected tokens. + selected_probs = torch.gather(probs, -1, sampled_ids.long()[..., None]) + selected_probs = selected_probs.squeeze(-1) + + # Ignores the tokens given in the input by overwriting their confidence. + selected_probs = torch.where(unknown_map, selected_probs, torch.finfo(selected_probs.dtype).max) + # Gets mask lens for each sample in the batch according to the mask ratio. + mask_len = (num_vq_tokens * mask_ratio).floor().unsqueeze(0).to(logits.device) + # Keeps at least one of prediction in this round and also masks out at least + # one and for the next iteration + mask_len = torch.max( + torch.tensor([1], device=logits.device), torch.min(unknown_map.sum(dim=-1, keepdim=True) - 1, mask_len) + ) + # print(f"mask_len: {mask_len}, mask_len.shape: {mask_len.shape}") + # Adds noise for randomness + temperature = temperature * (1.0 - ratio) + masking = mask_by_random_topk(mask_len, selected_probs, temperature, generator=generator) + # Masks tokens with lower confidence. + input_ids[:, -(num_vq_tokens + 1):-1] = torch.where(masking, mask_token_id, + sampled_ids + len(uni_prompting.text_tokenizer) + + num_new_special_tokens) + input_ids_minus_lm_vocab_size = torch.where(masking, mask_token_id, sampled_ids) + + return sampled_ids + + @torch.no_grad() + def t2s_generate( + self, + input_ids: torch.LongTensor = None, + uncond_input_ids: torch.LongTensor = None, + attention_mask=None, + uncond_attention_mask=None, + temperature=1.0, + timesteps=18, + guidance_scale=0, + noise_schedule=None, + generator: torch.Generator = None, + config=None, + seq_len=256, + mask_token_id=126336, + **kwargs, + ): + uni_prompting = kwargs.get("uni_prompting", None) + if uni_prompting is None: + raise ValueError("uni_prompting object must be provided in kwargs.") + + eoa_token_id = uni_prompting.sptids_dict['<|eoa|>'][0].item() + eos_token_id = uni_prompting.text_tokenizer.eos_token_id + + num_vq_tokens = (input_ids == mask_token_id).sum(dim=-1).max().item() + if num_vq_tokens == 0: + raise ValueError("No mask tokens found in input_ids.") + + speech_vocab_start_idx = len(uni_prompting.text_tokenizer) + 8192 + speech_vocab_end_idx = speech_vocab_start_idx + 4096 + + # VQ Codes: 0 ~ 4095 + # EOA: 4096 + # EOS: 4097 + vq_code_relative_eoa_id = 4096 + vq_code_relative_eos_id = 4097 + + input_ids_relative = input_ids[:, -(num_vq_tokens):].clone() + input_ids_relative = torch.where( + input_ids_relative == mask_token_id, + mask_token_id, + input_ids_relative - speech_vocab_start_idx + ) + + if uncond_input_ids is not None: + start_gen_idx = (uncond_input_ids[0] == uni_prompting.sptids_dict['<|soa|>'][0].item()).nonzero(as_tuple=True)[0][0].item() + 1 + uncond_prefix = uncond_input_ids[:, :start_gen_idx] + + for step in range(timesteps): + if uncond_input_ids is not None and guidance_scale > 0: + uncond_input_ids = torch.cat([uncond_prefix, input_ids[:, start_gen_idx:]], dim=1) + model_input = torch.cat([input_ids, uncond_input_ids]) + all_attention_mask = torch.cat([attention_mask, uncond_attention_mask], dim=0) + + attention_bias = (all_attention_mask[:, :, None] & all_attention_mask[:, None, :]).bool().unsqueeze(1) + logits = self(model_input, attention_bias=attention_bias).logits + cond_logits, uncond_logits = torch.chunk(logits, 2, dim=0) + + logits = (1 + guidance_scale) * cond_logits - guidance_scale * uncond_logits + + else: + attention_bias = (attention_mask[:, :, None] & attention_mask[:, None, :]).bool().unsqueeze(1) + logits = self(input_ids, attention_bias=attention_bias).logits + + logits_vq = logits[:, -(num_vq_tokens):, speech_vocab_start_idx:speech_vocab_end_idx] + logits_eoa = logits[:, -(num_vq_tokens):, eoa_token_id:eoa_token_id+1] + logits_eos = logits[:, -(num_vq_tokens):, eos_token_id:eos_token_id+1] + + combined_logits = torch.cat([logits_vq, logits_eoa, logits_eos], dim=-1) + + probs = combined_logits.softmax(dim=-1) + sampled = probs.reshape(-1, combined_logits.size(-1)) + + sampled_ids_relative = torch.multinomial(sampled, 1, generator=generator)[:, 0].view(*combined_logits.shape[:-1]) + + unknown_map = input_ids_relative == mask_token_id + + sampled_ids_relative = torch.where(unknown_map, sampled_ids_relative, input_ids_relative) + + ratio = 1.0 * (step + 1) / timesteps + mask_ratio = noise_schedule(torch.tensor(ratio, device=logits.device)) + + selected_probs = torch.gather(probs, -1, sampled_ids_relative.long()[..., None]).squeeze(-1) + selected_probs = torch.where(unknown_map, selected_probs, torch.finfo(selected_probs.dtype).max) + + mask_len = (num_vq_tokens * mask_ratio).floor().unsqueeze(0).to(logits.device) + mask_len = torch.max( + torch.tensor([1], device=logits.device), + torch.min(unknown_map.sum(dim=-1, keepdim=True) - 1, mask_len) + ) + + temperature = temperature * (1.0 - ratio) + masking = mask_by_random_topk(mask_len, selected_probs, temperature, generator=generator) + + input_ids[:, -(num_vq_tokens):] = torch.where( + masking, + mask_token_id, + torch.where( + sampled_ids_relative == vq_code_relative_eos_id, + eos_token_id, + torch.where( + sampled_ids_relative == vq_code_relative_eoa_id, + eoa_token_id, + sampled_ids_relative + speech_vocab_start_idx + ) + ) + ) + + input_ids_relative = torch.where(masking, mask_token_id, sampled_ids_relative) + + # print("--- Generation Loop Finished ---") + # print("Final sequence BEFORE post-processing (relative IDs):") + # print(input_ids_relative[0]) + # print(f"Shape: {input_ids_relative.shape}") + # print("---------------------------------") + + final_output_ids = [] + for i in range(input_ids_relative.shape[0]): + seq = input_ids_relative[i] + + eoa_indices = (seq >= vq_code_relative_eoa_id).nonzero(as_tuple=True)[0] + + if eoa_indices.numel() > 0: + first_eoa_idx = eoa_indices[0] + seq = seq[:first_eoa_idx] + + valid_tokens = seq[seq != mask_token_id] + + final_output_ids.append(valid_tokens) + + return final_output_ids + + @torch.no_grad() + def t2s_generate_mmu_like( + self, + input_ids: torch.LongTensor, + max_new_tokens: Optional[int] = None, + steps: int = 256, + block_length: int = 128, + temperature: float = 0.0, + cfg_scale: float = 0.0, + mask_token_id: int = 126336, + attention_mask: Optional[torch.LongTensor] = None, + uni_prompting=None, + codebook_size: Optional[int] = None, + audio_codebook_size: int = 4096, + ): + """ + Generate speech tokens with MMU-style block-wise refinement. + Assumes the speech region within ``input_ids`` is contiguous and filled with ``mask_token_id`` + prior to generation. + """ + + if uni_prompting is None: + raise ValueError("uni_prompting must be provided") + if block_length <= 0: + raise ValueError("block_length must be positive") + + batch_size, seq_len = input_ids.shape + device = input_ids.device + + mask_positions_full = (input_ids == mask_token_id) + if not mask_positions_full.any(): + raise ValueError("No mask tokens detected for T2S generation") + + mask_cols = torch.where(mask_positions_full[0])[0] + speech_region_start = mask_cols[0].item() + speech_region_len = mask_cols.numel() + + mask_counts = mask_positions_full.sum(dim=1) + if not torch.all(mask_counts == mask_counts[0]): + raise ValueError("All batch items must contain the same number of masked speech tokens for MMU-like generation") + + if max_new_tokens is None: + max_new_tokens = speech_region_len + else: + max_new_tokens = min(max_new_tokens, speech_region_len) + + block_length = max(1, min(block_length, max_new_tokens)) + num_blocks = math.ceil(max_new_tokens / block_length) + inner_steps = max(1, steps // num_blocks) + + codebook_base = codebook_size if codebook_size is not None else getattr(self.config, "codebook_size", 8192) + speech_vocab_start = len(uni_prompting.text_tokenizer) + codebook_base + speech_vocab_end = speech_vocab_start + audio_codebook_size + + eoa_token_id = uni_prompting.sptids_dict['<|eoa|>'][0].item() + eos_token_id = uni_prompting.text_tokenizer.eos_token_id + vq_code_relative_eoa_id = audio_codebook_size + vq_code_relative_eos_id = audio_codebook_size + 1 + + work = input_ids.clone() + + attention_bias = None + if attention_mask is not None: + attention_bias = (attention_mask[:, :, None] & attention_mask[:, None, :]).bool().unsqueeze(1) + + speech_indices = mask_cols[:max_new_tokens] + + for block_idx in range(num_blocks): + block_start = block_idx * block_length + block_end = min(block_start + block_length, max_new_tokens) + curr_indices = speech_indices[block_start:block_end] + if curr_indices.numel() == 0: + continue + + block_mask = mask_positions_full[:, curr_indices] + num_transfer_tokens = get_num_transfer_tokens(block_mask, inner_steps) + + for inner_step in range(inner_steps): + if cfg_scale > 0.0: + un_cond = work.clone() + un_cond[:, speech_indices] = mask_token_id + stacked = torch.cat([work, un_cond], dim=0) + if attention_bias is not None: + att_bias = torch.cat([attention_bias, attention_bias], dim=0) + else: + att_bias = None + logits = self(stacked, attention_bias=att_bias).logits + cond_logits, uncond_logits = torch.chunk(logits, 2, dim=0) + logits = uncond_logits + (cfg_scale + 1.0) * (cond_logits - uncond_logits) + else: + logits = self(work, attention_bias=attention_bias).logits + + logits_block = logits.index_select(1, curr_indices.to(device)) + logits_vq = logits_block[:, :, speech_vocab_start:speech_vocab_end] + logits_eoa = logits_block[:, :, eoa_token_id:eoa_token_id + 1] + logits_eos = logits_block[:, :, eos_token_id:eos_token_id + 1] + + combined_logits = torch.cat([logits_vq, logits_eoa, logits_eos], dim=-1) + if temperature > 0.0: + combined_logits = combined_logits / max(temperature, 1e-5) + probs = F.softmax(combined_logits, dim=-1) + + sampled = torch.multinomial( + probs.view(-1, probs.size(-1)), 1 + ).view(batch_size, curr_indices.numel()) + + selected_probs = torch.gather(probs, -1, sampled.unsqueeze(-1)).squeeze(-1) + + eos_tensor = sampled.new_full(sampled.shape, eos_token_id) + eoa_tensor = sampled.new_full(sampled.shape, eoa_token_id) + sampled_absolute = torch.where( + sampled == vq_code_relative_eos_id, + eos_tensor, + torch.where( + sampled == vq_code_relative_eoa_id, + eoa_tensor, + sampled + speech_vocab_start + ) + ) + + current_block_vals = work.index_select(1, curr_indices) + mask_current = current_block_vals == mask_token_id + + confidence = torch.where( + mask_current, + selected_probs, + torch.full_like(selected_probs, float('-inf')) + ) + + finalize = torch.zeros_like(mask_current, dtype=torch.bool) + for b in range(batch_size): + available = mask_current[b].sum().item() + if available == 0: + continue + transfer = min(int(num_transfer_tokens[b, inner_step].item()), available) + if transfer <= 0: + continue + _, idxs = torch.topk(confidence[b], k=transfer, largest=True) + finalize[b, idxs] = True + + mask_fill = sampled_absolute.new_full(sampled_absolute.shape, mask_token_id) + updates = torch.where(finalize, sampled_absolute, mask_fill) + new_block = torch.where(mask_current, updates, current_block_vals) + + work[:, curr_indices] = new_block + mask_positions_full[:, curr_indices] = new_block == mask_token_id + + if not mask_positions_full[:, curr_indices].any(): + break + + final_outputs = [] + audio_slice = slice(speech_region_start, speech_region_start + speech_region_len) + audio_region = work[:, audio_slice] + + for seq in audio_region: + mask_tensor = seq.new_full(seq.shape, mask_token_id) + rel_eoa = seq.new_full(seq.shape, vq_code_relative_eoa_id) + rel_eos = seq.new_full(seq.shape, vq_code_relative_eos_id) + relative = torch.where( + seq == mask_token_id, + mask_tensor, + torch.where( + seq == eoa_token_id, + rel_eoa, + torch.where( + seq == eos_token_id, + rel_eos, + seq - speech_vocab_start + ) + ) + ) + + eoa_positions = (relative >= vq_code_relative_eoa_id).nonzero(as_tuple=True)[0] + if eoa_positions.numel() > 0: + relative = relative[:eoa_positions[0]] + + final_outputs.append(relative[relative != mask_token_id]) + + return final_outputs + + @torch.no_grad() + def t2s_fixed_generate( + self, + input_ids: torch.LongTensor = None, + uncond_input_ids: torch.LongTensor = None, + attention_mask=None, + uncond_attention_mask=None, + temperature=1.0, + timesteps=18, + guidance_scale=0, + noise_schedule=None, + generator: torch.Generator = None, + config=None, + seq_len=256, + mask_token_id=126336, + **kwargs, + ): + """ + Generate 1:1 similar to the original MaskGit repo + https://github.com/google-research/maskgit/blob/main/maskgit/libml/parallel_decode.py#L79 + """ + + # begin with all image token ids masked + # č®”ē®—ęœ‰å¤šå°‘äøŖmask token + mask_count = (input_ids == mask_token_id).sum().item() + num_vq_tokens = seq_len + num_new_special_tokens = 0 + uni_prompting = kwargs.get("uni_prompting", None) + # print(f"config.model.mmada.llm_vocab_size: {config.model.mmada.llm_vocab_size}, {len(uni_prompting.text_tokenizer)}") + input_ids_minus_lm_vocab_size = input_ids[:, -(num_vq_tokens + 1):-1].clone() + input_ids_minus_lm_vocab_size = torch.where(input_ids_minus_lm_vocab_size == mask_token_id, mask_token_id, input_ids_minus_lm_vocab_size - len(uni_prompting.text_tokenizer) - num_new_special_tokens - 8192) + + # for classifier-free guidance + if uncond_input_ids is not None: + start_gen_idx = (uncond_input_ids[0] == uni_prompting.sptids_dict['<|soa|>'][0].item()).nonzero(as_tuple=True)[0][0].item() + 1 + uncond_prefix = uncond_input_ids[:, :start_gen_idx] + + for step in range(timesteps): + if uncond_input_ids is not None and guidance_scale > 0: + uncond_input_ids = torch.cat( + [uncond_prefix, input_ids[:, start_gen_idx:]], dim=1) + model_input = torch.cat([input_ids, uncond_input_ids]) + all_attention_mask = torch.cat([attention_mask, uncond_attention_mask], dim=0) + attention_bias = (all_attention_mask[:, :, None] & all_attention_mask[:, None, :]).bool().unsqueeze(1) + logits = self(model_input, attention_bias=attention_bias).logits + # print(f"logits.shape: {logits.shape}") + cond_logits, uncond_logits = torch.chunk(logits, 2, dim=0) + # logits = uncond_logits + guidance_scale * (cond_logits - uncond_logits) + # it seems that muse has a different cfg setting + logits = (1 + guidance_scale) * cond_logits - guidance_scale * uncond_logits + logits = logits[:, -(num_vq_tokens + 1):-1, len(uni_prompting.text_tokenizer) + num_new_special_tokens + 8192 : len(uni_prompting.text_tokenizer) + num_new_special_tokens + 8192 + 4096] + else: + attention_bias = (attention_mask[:, :, None] & attention_mask[:, None, :]).bool().unsqueeze(1) + logits = self(input_ids, attention_bias=attention_bias).logits + logits = logits[:, -(num_vq_tokens + 1):-1, len(uni_prompting.text_tokenizer) + num_new_special_tokens + 8192 : len(uni_prompting.text_tokenizer) + num_new_special_tokens + 8192 + 4096] + + # logits: 1, 1024, 8192 + # print(f"logits.shape: {logits.shape}") + probs = logits.softmax(dim=-1) + sampled = probs.reshape(-1, logits.size(-1)) + # print(f"probs: {probs}, probs.shape: {probs.shape}, sampled: {sampled}, sampled.shape: {sampled.shape}") + sampled_ids = torch.multinomial(sampled, 1, generator=generator)[:, 0].view(*logits.shape[:-1]) # 1, 1024 + + unknown_map = input_ids_minus_lm_vocab_size == mask_token_id + # print(f"unknown_map.sum(dim=-1, keepdim=True): {unknown_map.sum(dim=-1, keepdim=True)}") + sampled_ids = torch.where(unknown_map, sampled_ids, input_ids_minus_lm_vocab_size) + # Defines the mask ratio for the next round. The number to mask out is + # determined by mask_ratio * unknown_number_in_the_beginning. + ratio = 1.0 * (step + 1) / timesteps + mask_ratio = noise_schedule(torch.tensor(ratio)) + # Computes the probabilities of each selected tokens. + selected_probs = torch.gather(probs, -1, sampled_ids.long()[..., None]) + selected_probs = selected_probs.squeeze(-1) + + # Ignores the tokens given in the input by overwriting their confidence. + selected_probs = torch.where(unknown_map, selected_probs, torch.finfo(selected_probs.dtype).max) + # Gets mask lens for each sample in the batch according to the mask ratio. + mask_len = (num_vq_tokens * mask_ratio).floor().unsqueeze(0).to(logits.device) + # Keeps at least one of prediction in this round and also masks out at least + # one and for the next iteration + mask_len = torch.max( + torch.tensor([1], device=logits.device), torch.min(unknown_map.sum(dim=-1, keepdim=True) - 1, mask_len) + ) + # print(f"mask_len: {mask_len}, mask_len.shape: {mask_len.shape}") + # Adds noise for randomness + temperature = temperature * (1.0 - ratio) + masking = mask_by_random_topk(mask_len, selected_probs, temperature, generator=generator) + # Masks tokens with lower confidence. + input_ids[:, -(num_vq_tokens + 1):-1] = torch.where(masking, mask_token_id, + sampled_ids + len(uni_prompting.text_tokenizer) + + num_new_special_tokens + 8192) + input_ids_minus_lm_vocab_size = torch.where(masking, mask_token_id, sampled_ids) + + return sampled_ids + + @torch.no_grad() + def i2i_generate( + self, + input_ids: torch.LongTensor = None, + uncond_input_ids: torch.LongTensor = None, + attention_mask=None, + uncond_attention_mask=None, + temperature=1.0, + timesteps=18, # ideal number of steps is 18 in maskgit paper + guidance_scale=0, + noise_schedule=cosine_schedule, + generator: torch.Generator = None, + config=None, + seq_len=1024, + mask_token_id = 126336, + resolution = 512, + codebook_size = 8192, + **kwargs, + ): + """ + Generate 1:1 similar to the original MaskGit repo + https://github.com/google-research/maskgit/blob/main/maskgit/libml/parallel_decode.py#L79 + """ + + # begin with all image token ids masked + # č®”ē®—ęœ‰å¤šå°‘äøŖmask token + mask_count = (input_ids == mask_token_id).sum().item() + num_vq_tokens = seq_len + num_new_special_tokens = 0 + uni_prompting = kwargs.get("uni_prompting", None) + # print(f"config.model.mmada.llm_vocab_size: {config.model.mmada.llm_vocab_size}, {len(uni_prompting.text_tokenizer)}") + input_ids_minus_lm_vocab_size = input_ids[:, -(num_vq_tokens + 1):-1].clone() + input_ids_minus_lm_vocab_size = torch.where(input_ids_minus_lm_vocab_size == mask_token_id, mask_token_id, input_ids_minus_lm_vocab_size - len(uni_prompting.text_tokenizer) - num_new_special_tokens) + + # for classifier-free guidance + if uncond_input_ids is not None: + uncond_prefix = uncond_input_ids[:, :resolution + 1] + + for step in range(timesteps): + if uncond_input_ids is not None and guidance_scale > 0: + uncond_input_ids = torch.cat( + [uncond_prefix, input_ids[:, resolution + 1:]], dim=1) + model_input = torch.cat([input_ids, uncond_input_ids]) + all_attention_mask = torch.cat([attention_mask, uncond_attention_mask], dim=0) + attention_bias = (all_attention_mask[:, :, None] & all_attention_mask[:, None, :]).bool().unsqueeze(1) + logits = self(model_input, attention_bias=attention_bias).logits + # print(f"logits.shape: {logits.shape}") + cond_logits, uncond_logits = torch.chunk(logits, 2, dim=0) + # logits = uncond_logits + guidance_scale * (cond_logits - uncond_logits) + # it seems that muse has a different cfg setting + logits = (1 + guidance_scale) * cond_logits - guidance_scale * uncond_logits + logits = logits[:, -(num_vq_tokens + 1):-1, len(uni_prompting.text_tokenizer) + num_new_special_tokens: len(uni_prompting.text_tokenizer) + num_new_special_tokens + codebook_size] + else: + attention_bias = (attention_mask[:, :, None] & attention_mask[:, None, :]).bool().unsqueeze(1) + logits = self(input_ids, attention_bias=attention_bias).logits + logits = logits[:, -(num_vq_tokens + 1):-1, len(uni_prompting.text_tokenizer) + num_new_special_tokens: len(uni_prompting.text_tokenizer) + num_new_special_tokens + codebook_size] + + # logits: 1, 1024, 8192 + # print(f"logits.shape: {logits.shape}") + probs = logits.softmax(dim=-1) + sampled = probs.reshape(-1, logits.size(-1)) + # print(f"probs: {probs}, probs.shape: {probs.shape}, sampled: {sampled}, sampled.shape: {sampled.shape}") + sampled_ids = torch.multinomial(sampled, 1, generator=generator)[:, 0].view(*logits.shape[:-1]) # 1, 1024 + + unknown_map = input_ids_minus_lm_vocab_size == mask_token_id + # print(f"unknown_map.sum(dim=-1, keepdim=True): {unknown_map.sum(dim=-1, keepdim=True)}") + sampled_ids = torch.where(unknown_map, sampled_ids, input_ids_minus_lm_vocab_size) + # Defines the mask ratio for the next round. The number to mask out is + # determined by mask_ratio * unknown_number_in_the_beginning. + ratio = 1.0 * (step + 1) / timesteps + mask_ratio = noise_schedule(torch.tensor(ratio)) + # Computes the probabilities of each selected tokens. + selected_probs = torch.gather(probs, -1, sampled_ids.long()[..., None]) + selected_probs = selected_probs.squeeze(-1) + + # Ignores the tokens given in the input by overwriting their confidence. + selected_probs = torch.where(unknown_map, selected_probs, torch.finfo(selected_probs.dtype).max) + # Gets mask lens for each sample in the batch according to the mask ratio. + mask_len = (num_vq_tokens * mask_ratio).floor().unsqueeze(0).to(logits.device) + # Keeps at least one of prediction in this round and also masks out at least + # one and for the next iteration + mask_len = torch.max( + torch.tensor([1], device=logits.device), torch.min(unknown_map.sum(dim=-1, keepdim=True) - 1, mask_len) + ) + # print(f"mask_len: {mask_len}, mask_len.shape: {mask_len.shape}") + # Adds noise for randomness + temperature = temperature * (1.0 - ratio) + masking = mask_by_random_topk(mask_len, selected_probs, temperature, generator=generator) + # Masks tokens with lower confidence. + input_ids[:, -(num_vq_tokens + 1):-1] = torch.where(masking, mask_token_id, + sampled_ids + len(uni_prompting.text_tokenizer) + + num_new_special_tokens) + input_ids_minus_lm_vocab_size = torch.where(masking, mask_token_id, sampled_ids) + + return sampled_ids + + # def forward_process( + # self, + # input_ids, + # labels, + # batch_size_t2i=0, + # batch_size_lm=0, + # batch_size_mmu=0, + # batch_size_v2t=0, + # batch_size_s2t=0, + # batch_size_t2s=0, + # max_seq_length=128, + # p_mask_lm=None, + # p_mask_mmu=None, + # p_mask_vid=None, + # p_mask_s2t=None, + # p_mask_t2s=None, + # answer_lengths=None, + # t2i_masks=None, + # answer_lengths_lm=None + # ): + # # attention bias, True for batch_size, 1, seq_len, seq_len + # attention_bias = torch.ones(input_ids.shape[0], 1, input_ids.shape[1], input_ids.shape[1]) + # attention_bias_t2i = (t2i_masks[:, :, None] & t2i_masks[:, None, :]).bool().unsqueeze(1) + # attention_bias[:batch_size_t2i] = attention_bias_t2i + # logits = self(input_ids, attention_bias=attention_bias).logits + # self.output_size = logits.shape[-1] + + # if batch_size_t2i == 0: + # loss_t2i = torch.tensor(0.0, device=input_ids.device) + # else: + # loss_t2i = F.cross_entropy( + # logits[:batch_size_t2i, max_seq_length + 1:].contiguous().view(-1, self.output_size), + # labels[:batch_size_t2i, max_seq_length + 1:].contiguous().view(-1), ignore_index=-100, + # ) + + # masked_indices = input_ids == self.config.mask_token_id + # masked_indices_lm = masked_indices[batch_size_t2i:batch_size_t2i + batch_size_lm] + # masked_indices_mmu = masked_indices[-batch_size_mmu:] + # p_mask_lm = p_mask_lm.to(masked_indices_lm.device) + # p_mask_mmu = p_mask_mmu.to(masked_indices_mmu.device) + # answer_lengths = answer_lengths.to(masked_indices_mmu.device) + # loss_lm = F.cross_entropy( + # logits[batch_size_t2i:batch_size_t2i + batch_size_lm][masked_indices_lm].contiguous().view(-1, self.output_size), + # labels[batch_size_t2i:batch_size_t2i + batch_size_lm][masked_indices_lm].contiguous().view(-1), ignore_index=-100, reduction='none' + # )/p_mask_lm[masked_indices_lm] + + # if answer_lengths_lm is not None: + # loss_lm = torch.sum(loss_lm / answer_lengths_lm[masked_indices_lm]) / (logits[batch_size_t2i:batch_size_t2i + batch_size_lm].shape[0]) + # else: + # loss_lm = loss_lm.sum() / (logits[batch_size_t2i:batch_size_t2i + batch_size_lm].shape[0] * logits[batch_size_t2i:batch_size_t2i + batch_size_lm].shape[1]) + + # loss_mmu = F.cross_entropy( + # logits[-batch_size_mmu:][masked_indices_mmu].contiguous().view(-1, self.output_size), + # labels[-batch_size_mmu:][masked_indices_mmu].contiguous().view(-1), ignore_index=-100, reduction='none' + # )/p_mask_mmu[masked_indices_mmu] + # loss_mmu = torch.sum(loss_mmu/answer_lengths[masked_indices_mmu]) / (logits[-batch_size_mmu:].shape[0]) + + # return logits, loss_t2i, loss_lm, loss_mmu + + # def forward_process( + # self, + # input_ids, + # labels, + # batch_size_t2i=0, + # batch_size_lm=0, + # batch_size_mmu=0, + # batch_size_v2t=0, + # batch_size_s2t=0, + # batch_size_t2s=0, + # max_seq_length=128, + # p_mask_lm=None, + # p_mask_mmu=None, + # p_mask_vid=None, + # p_mask_s2t=None, + # p_mask_t2s=None, + # answer_lengths_lm=None, + # answer_lengths_mmu=None, + # answer_lengths_vid=None, + # answer_lengths_s2t=None, + # answer_lengths_t2s=None, + # t2i_masks=None, + # t2s_vocab_start=None, + # t2s_codebook_size=None, + # t2s_special_token_ids=None + # ): + # # --- 1. Attention Bias Setup (no changes) --- + # attention_bias = torch.ones(input_ids.shape[0], 1, input_ids.shape[1], input_ids.shape[1], device=input_ids.device) + # if batch_size_t2i > 0 and t2i_masks is not None: + # attention_bias_t2i = (t2i_masks[:, :, None] & t2i_masks[:, None, :]).bool().unsqueeze(1) + # attention_bias[:batch_size_t2i] = attention_bias_t2i + + # # --- 2. Model Forward Pass (no changes) --- + # logits = self(input_ids, attention_bias=attention_bias).logits + # self.output_size = logits.shape[-1] + + # # --- 3. Loss Calculation --- + # device = input_ids.device + # zero_loss = torch.tensor(0.0, device=device) + + # # Calculate masked indices for the entire batch + # masked_indices = (input_ids == self.config.mask_token_id) + + # current_idx = 0 + + # # --- T2I Loss --- + # if batch_size_t2i > 0: + # loss_t2i = F.cross_entropy( + # logits[current_idx:current_idx + batch_size_t2i, max_seq_length + 1:].contiguous().view(-1, self.output_size), + # labels[current_idx:current_idx + batch_size_t2i, max_seq_length + 1:].contiguous().view(-1), ignore_index=-100, + # ) + # else: + # loss_t2i = zero_loss + # current_idx += batch_size_t2i + + # # --- LM Loss --- + # if batch_size_lm > 0: + # start, end = current_idx, current_idx + batch_size_lm + # logits_lm, labels_lm = logits[start:end], labels[start:end] + # masked_indices_lm = masked_indices[start:end] + + # loss_lm = F.cross_entropy( + # logits_lm[masked_indices_lm].contiguous().view(-1, self.output_size), + # labels_lm[masked_indices_lm].contiguous().view(-1), ignore_index=-100, reduction='none' + # ) / p_mask_lm.to(device)[masked_indices_lm] + + # if answer_lengths_lm is not None: + # loss_lm = torch.sum(loss_lm / answer_lengths_lm.to(device)[masked_indices_lm]) / logits_lm.shape[0] + # else: + # loss_lm = loss_lm.sum() / logits_lm.shape[0] + # else: + # loss_lm = zero_loss + # current_idx += batch_size_lm + + # # --- MMU Loss --- + # if batch_size_mmu > 0: + # start, end = current_idx, current_idx + batch_size_mmu + # loss_mmu = calculate_mmu_style_loss( + # logits[start:end], labels[start:end], masked_indices[start:end], + # p_mask_mmu, answer_lengths_mmu, self.output_size, device + # ) + # else: + # loss_mmu = zero_loss + # current_idx += batch_size_mmu + + # # --- VID (V2T) Loss --- + # if batch_size_v2t > 0: + # start, end = current_idx, current_idx + batch_size_v2t + # loss_vid = calculate_mmu_style_loss( + # logits[start:end], labels[start:end], masked_indices[start:end], + # p_mask_vid, answer_lengths_vid, self.output_size, device + # ) + # else: + # loss_vid = zero_loss + # current_idx += batch_size_v2t + + # # --- S2T Loss --- + # if batch_size_s2t > 0: + # start, end = current_idx, current_idx + batch_size_s2t + # loss_s2t = calculate_mmu_style_loss( + # logits[start:end], labels[start:end], masked_indices[start:end], + # p_mask_s2t, answer_lengths_s2t, self.output_size, device + # ) + # else: + # loss_s2t = zero_loss + # current_idx += batch_size_s2t + + # # --- T2S Loss --- + # if batch_size_t2s > 0: + # start, end = current_idx, current_idx + batch_size_t2s + # if ( + # t2s_vocab_start is not None + # and t2s_codebook_size is not None + # and t2s_special_token_ids is not None + # ): + # eoa_id = t2s_special_token_ids.get('eoa') + # eos_id = t2s_special_token_ids.get('eos') + # else: + # eoa_id = eos_id = None + + # if eoa_id is not None and eos_id is not None: + # loss_t2s = calculate_t2s_loss( + # logits[start:end], + # labels[start:end], + # masked_indices[start:end], + # p_mask_t2s, + # answer_lengths_t2s, + # t2s_vocab_start, + # t2s_codebook_size, + # eoa_id, + # eos_id, + # device, + # ignore_index=-100, + # ) + # else: + # loss_t2s = calculate_mmu_style_loss( + # logits[start:end], labels[start:end], masked_indices[start:end], + # p_mask_t2s, answer_lengths_t2s, self.output_size, device + # ) + # else: + # loss_t2s = zero_loss + # current_idx += batch_size_t2s + + # return logits, loss_t2i, loss_lm, loss_mmu, loss_vid, loss_s2t, loss_t2s + + def forward_process( + self, + input_ids, + labels, + batch_size_t2i=0, + batch_size_i2i=0, + batch_size_lm=0, + batch_size_mmu=0, + batch_size_v2t=0, + batch_size_s2t=0, + batch_size_s2s=0, + batch_size_t2s=0, + max_seq_length=128, + p_mask_lm=None, + p_mask_mmu=None, + p_mask_vid=None, + p_mask_s2t=None, + p_mask_s2s=None, + p_mask_t2s=None, + answer_lengths_lm=None, + answer_lengths_mmu=None, + answer_lengths_vid=None, + answer_lengths_s2t=None, + answer_lengths_s2s=None, + answer_lengths_t2s=None, + t2i_masks=None, + attention_masks_i2i=None, + t2s_vocab_start=None, + t2s_codebook_size=None, + t2s_special_token_ids=None + ): + # --- 1. Attention Bias Setup (no changes) --- + attention_bias = torch.ones(input_ids.shape[0], 1, input_ids.shape[1], input_ids.shape[1], device=input_ids.device) + if batch_size_t2i > 0 and t2i_masks is not None: + attention_bias_t2i = (t2i_masks[:, :, None] & t2i_masks[:, None, :]).bool().unsqueeze(1) + attention_bias[:batch_size_t2i] = attention_bias_t2i + + if batch_size_i2i > 0 and attention_masks_i2i is not None: + start_i2i = batch_size_t2i + end_i2i = start_i2i + batch_size_i2i + attn_mask = attention_masks_i2i.to(input_ids.device) + if attn_mask.dtype != torch.bool: + attn_mask = attn_mask.bool() + attention_bias_i2i = (attn_mask[:, :, None] & attn_mask[:, None, :]).unsqueeze(1) + attention_bias[start_i2i:end_i2i] = attention_bias_i2i + + # --- 2. Model Forward Pass (no changes) --- + logits = self(input_ids, attention_bias=attention_bias).logits + self.output_size = logits.shape[-1] + + # --- 3. Loss Calculation --- + device = input_ids.device + zero_loss = torch.tensor(0.0, device=device) + + # Calculate masked indices for the entire batch + masked_indices = (input_ids == self.config.mask_token_id) + + text_vocab_size = getattr(self.config, "llm_vocab_size", None) + if text_vocab_size is None: + text_vocab_size = getattr(self.config, "vocab_size", logits.shape[-1]) + image_vocab_size = getattr(self.config, "codebook_size", 0) + image_vocab_start = text_vocab_size + image_vocab_end = min(image_vocab_start + image_vocab_size, logits.shape[-1]) + + current_idx = 0 + + # --- T2I Loss --- + if batch_size_t2i > 0: + logits_t2i = logits[current_idx:current_idx + batch_size_t2i, max_seq_length + 1:] + labels_t2i = labels[current_idx:current_idx + batch_size_t2i, max_seq_length + 1:] + if image_vocab_size > 0: + effective_vocab = image_vocab_end - image_vocab_start + if effective_vocab > 0: + logits_slice = logits_t2i[..., image_vocab_start:image_vocab_end] + labels_relative = torch.full_like(labels_t2i, -100) + valid_mask = (labels_t2i >= image_vocab_start) & (labels_t2i < image_vocab_end) + labels_relative[valid_mask] = labels_t2i[valid_mask] - image_vocab_start + loss_t2i = F.cross_entropy( + logits_slice.contiguous().view(-1, effective_vocab), + labels_relative.contiguous().view(-1), + ignore_index=-100, + ) + else: + loss_t2i = F.cross_entropy( + logits_t2i.contiguous().view(-1, self.output_size), + labels_t2i.contiguous().view(-1), + ignore_index=-100, + ) + else: + loss_t2i = F.cross_entropy( + logits_t2i.contiguous().view(-1, self.output_size), + labels_t2i.contiguous().view(-1), + ignore_index=-100, + ) + else: + loss_t2i = zero_loss + current_idx += batch_size_t2i + + # --- I2I Loss --- + if batch_size_i2i > 0: + start, end = current_idx, current_idx + batch_size_i2i + logits_i2i = logits[start:end] + labels_i2i = labels[start:end] + if image_vocab_size > 0: + effective_vocab = image_vocab_end - image_vocab_start + if effective_vocab > 0: + logits_slice = logits_i2i[..., image_vocab_start:image_vocab_end] + labels_relative = torch.full_like(labels_i2i, -100) + image_mask = (labels_i2i >= image_vocab_start) & (labels_i2i < image_vocab_end) + labels_relative[image_mask] = labels_i2i[image_mask] - image_vocab_start + loss_i2i = F.cross_entropy( + logits_slice.contiguous().view(-1, effective_vocab), + labels_relative.contiguous().view(-1), + ignore_index=-100, + ) + else: + loss_i2i = F.cross_entropy( + logits_i2i.view(-1, self.output_size), + labels_i2i.view(-1), + ignore_index=-100, + ) + else: + loss_i2i = F.cross_entropy( + logits_i2i.view(-1, self.output_size), + labels_i2i.view(-1), + ignore_index=-100, + ) + else: + loss_i2i = zero_loss + current_idx += batch_size_i2i + + # --- LM Loss --- + if batch_size_lm > 0: + start, end = current_idx, current_idx + batch_size_lm + logits_lm, labels_lm = logits[start:end], labels[start:end] + masked_indices_lm = masked_indices[start:end] + selected_logits_lm = logits_lm[masked_indices_lm] + effective_vocab_lm = selected_logits_lm.shape[-1] + if text_vocab_size and text_vocab_size < self.output_size: + effective_vocab_lm = min(text_vocab_size, selected_logits_lm.shape[-1]) + selected_logits_lm = selected_logits_lm[:, :effective_vocab_lm] + loss_lm = F.cross_entropy( + selected_logits_lm.contiguous().view(-1, effective_vocab_lm), + labels_lm[masked_indices_lm].contiguous().view(-1), ignore_index=-100, reduction='none' + ) / p_mask_lm.to(device)[masked_indices_lm] + + if answer_lengths_lm is not None: + loss_lm = torch.sum(loss_lm / answer_lengths_lm.to(device)[masked_indices_lm]) / logits_lm.shape[0] + else: + loss_lm = loss_lm.sum() / logits_lm.shape[0] + else: + loss_lm = zero_loss + current_idx += batch_size_lm + + # --- MMU Loss --- + if batch_size_mmu > 0: + start, end = current_idx, current_idx + batch_size_mmu + loss_mmu = calculate_mmu_style_loss( + logits[start:end], labels[start:end], masked_indices[start:end], + p_mask_mmu, answer_lengths_mmu, self.output_size, device, + vocab_start=0, + vocab_end=text_vocab_size, + ) + else: + loss_mmu = zero_loss + current_idx += batch_size_mmu + + # --- VID (V2T) Loss --- + if batch_size_v2t > 0: + start, end = current_idx, current_idx + batch_size_v2t + loss_vid = calculate_mmu_style_loss( + logits[start:end], labels[start:end], masked_indices[start:end], + p_mask_vid, answer_lengths_vid, self.output_size, device, + vocab_start=0, + vocab_end=text_vocab_size, + ) + else: + loss_vid = zero_loss + current_idx += batch_size_v2t + + # --- S2T Loss --- + if batch_size_s2t > 0: + start, end = current_idx, current_idx + batch_size_s2t + loss_s2t = calculate_mmu_style_loss( + logits[start:end], labels[start:end], masked_indices[start:end], + p_mask_s2t, answer_lengths_s2t, self.output_size, device, + vocab_start=0, + vocab_end=text_vocab_size, + ) + else: + loss_s2t = zero_loss + current_idx += batch_size_s2t + + # --- S2S Loss --- + if batch_size_s2s > 0: + start, end = current_idx, current_idx + batch_size_s2s + if ( + t2s_vocab_start is not None + and t2s_codebook_size is not None + and t2s_special_token_ids is not None + and p_mask_s2s is not None + and answer_lengths_s2s is not None + ): + eoa_id = t2s_special_token_ids.get('eoa') + eos_id = t2s_special_token_ids.get('eos') + else: + eoa_id = eos_id = None + + if eoa_id is not None and eos_id is not None: + loss_s2s = calculate_t2s_loss( + logits[start:end], + labels[start:end], + masked_indices[start:end], + p_mask_s2s, + answer_lengths_s2s, + t2s_vocab_start, + t2s_codebook_size, + eoa_id, + eos_id, + device, + ignore_index=-100, + ) + else: + loss_s2s = calculate_mmu_style_loss( + logits[start:end], labels[start:end], masked_indices[start:end], + p_mask_s2s, answer_lengths_s2s, self.output_size, device + ) + else: + loss_s2s = zero_loss + current_idx += batch_size_s2s + + # --- T2S Loss --- + if batch_size_t2s > 0: + start, end = current_idx, current_idx + batch_size_t2s + if ( + t2s_vocab_start is not None + and t2s_codebook_size is not None + and t2s_special_token_ids is not None + ): + eoa_id = t2s_special_token_ids.get('eoa') + eos_id = t2s_special_token_ids.get('eos') + else: + eoa_id = eos_id = None + + if eoa_id is not None and eos_id is not None: + loss_t2s = calculate_t2s_loss( + logits[start:end], + labels[start:end], + masked_indices[start:end], + p_mask_t2s, + answer_lengths_t2s, + t2s_vocab_start, + t2s_codebook_size, + eoa_id, + eos_id, + device, + ignore_index=-100, + ) + else: + loss_t2s = calculate_mmu_style_loss( + logits[start:end], labels[start:end], masked_indices[start:end], + p_mask_t2s, answer_lengths_t2s, self.output_size, device + ) + else: + loss_t2s = zero_loss + current_idx += batch_size_t2s + + return logits, loss_t2i, loss_i2i, loss_lm, loss_mmu, loss_vid, loss_s2t, loss_s2s, loss_t2s + + + def forward_process_with_r2i( + self, + input_ids, + labels, + t2i_masks=None, + max_seq_length=128, + batch_size_t2i=0, + batch_size_lm=0, + batch_size_mmu=0, + batch_size_r2i=0, + p_mask_lm=None, + p_mask_mmu=None, + p_mask_r2i=None, + answer_lengths=None, + answer_lengths_lm=None, + answer_lengths_r2i=None, + ): + # attention bias, True for batch_size, 1, seq_len, seq_len + attention_bias = torch.ones(input_ids.shape[0], 1, input_ids.shape[1], input_ids.shape[1]) + attention_bias_t2i = (t2i_masks[:, :, None] & t2i_masks[:, None, :]).bool().unsqueeze(1) + attention_bias[:batch_size_t2i] = attention_bias_t2i + logits = self(input_ids, attention_bias=attention_bias).logits + # logits = self(input_ids).logits + self.output_size = logits.shape[-1] + + if batch_size_t2i == 0: + loss_t2i = torch.tensor(0.0, device=input_ids.device) + else: + # t2i loss + loss_t2i = F.cross_entropy( + logits[:batch_size_t2i, max_seq_length + 1:].contiguous().view(-1, self.output_size), + labels[:batch_size_t2i, max_seq_length + 1:].contiguous().view(-1), ignore_index=-100, + ) + + # llada loss + + start_lm = batch_size_t2i + end_lm = start_lm + batch_size_lm + start_mmu = end_lm + end_mmu = start_mmu + batch_size_mmu + start_r2i = end_mmu + end_r2i = start_r2i + batch_size_r2i + + masked_indices = input_ids == self.config.mask_token_id + masked_indices_lm = masked_indices[start_lm:end_lm] + masked_indices_mmu = masked_indices[start_mmu:end_mmu] + masked_indices_r2i = masked_indices[start_r2i:end_r2i] + + p_mask_lm = p_mask_lm.to(masked_indices_lm.device) + p_mask_mmu = p_mask_mmu.to(masked_indices_mmu.device) + p_mask_r2i = p_mask_r2i.to(masked_indices_r2i.device) + + answer_lengths = answer_lengths.to(masked_indices_mmu.device) + answer_lengths_lm = answer_lengths_lm.to(masked_indices_lm.device) + answer_lengths_r2i = answer_lengths_r2i.to(masked_indices_r2i.device) + + loss_lm = F.cross_entropy( + logits[start_lm:end_lm][masked_indices_lm].contiguous().view(-1, self.output_size), + labels[start_lm:end_lm][masked_indices_lm].contiguous().view(-1), ignore_index=-100, reduction='none' + )/p_mask_lm[masked_indices_lm] + + if answer_lengths_lm is not None: + loss_lm = torch.sum(loss_lm / answer_lengths_lm[masked_indices_lm]) / (logits[start_lm:end_lm].shape[0]) + else: + loss_lm = loss_lm.sum() / (logits[start_lm:end_lm].shape[0] * logits[start_lm:end_lm].shape[1]) + + loss_mmu = F.cross_entropy( + logits[start_mmu:end_mmu][masked_indices_mmu].contiguous().view(-1, self.output_size), + labels[start_mmu:end_mmu][masked_indices_mmu].contiguous().view(-1), ignore_index=-100, reduction='none' + )/p_mask_mmu[masked_indices_mmu] + loss_mmu = torch.sum(loss_mmu/answer_lengths[masked_indices_mmu]) / (logits[start_mmu:end_mmu].shape[0]) + + loss_r2i = F.cross_entropy( + logits[start_r2i:end_r2i][masked_indices_r2i].contiguous().view(-1, self.output_size), + labels[start_r2i:end_r2i][masked_indices_r2i].contiguous().view(-1), ignore_index=-100, reduction='none' + )/p_mask_r2i[masked_indices_r2i] + loss_r2i = torch.sum(loss_r2i/answer_lengths_r2i[masked_indices_r2i]) / (logits[start_r2i:end_r2i].shape[0]) + + return logits, loss_t2i, loss_lm, loss_mmu, loss_r2i + + def forward_t2i( + self, + input_ids, + labels, + batch_size_t2i=0, + max_seq_length=128, + t2i_masks=None + ): + # attention bias, True for batch_size, 1, seq_len, seq_len + attention_bias = torch.ones(input_ids.shape[0], 1, input_ids.shape[1], input_ids.shape[1]) + attention_bias_t2i = (t2i_masks[:, :, None] & t2i_masks[:, None, :]).bool().unsqueeze(1) + attention_bias[:batch_size_t2i] = attention_bias_t2i + logits = self(input_ids, attention_bias=attention_bias).logits + # logits = self(input_ids).logits + self.output_size = logits.shape[-1] + + # print(f"logits shape: {logits.shape}") B, 359, vocab_size + + loss_t2i = F.cross_entropy( + logits[:batch_size_t2i, max_seq_length + 1:].contiguous().view(-1, self.output_size), + labels[:batch_size_t2i, max_seq_length + 1:].contiguous().view(-1), ignore_index=-100, + ) + + return loss_t2i + + # Temp + def forward_i2i(self, input_ids, attention_mask, labels): + """ + Forward pass for the I2I task. + """ + outputs = self( + input_ids=input_ids, + attention_mask=attention_mask + ) + logits = outputs.logits + + loss = F.cross_entropy( + logits.view(-1, logits.size(-1)), + labels.view(-1), + ignore_index=-100 + ) + + return logits, loss + + # Temp + def forward_s2t( + self, + input_ids, + labels, + batch_size_s2t=0, + max_seq_length=128, + p_mask_s2t=None, + answer_lengths=None, + ): + # attention bias, True for batch_size, 1, seq_len, seq_len + attention_bias = torch.ones(input_ids.shape[0], 1, input_ids.shape[1], input_ids.shape[1], device=input_ids.device) + logits = self(input_ids, attention_bias=attention_bias).logits + self.output_size = logits.shape[-1] + + masked_indices = input_ids == self.config.mask_token_id + masked_indices_s2t = masked_indices[-batch_size_s2t:] + p_mask_s2t = p_mask_s2t.to(masked_indices_s2t.device) + answer_lengths = answer_lengths.to(masked_indices_s2t.device) + + loss_s2t = F.cross_entropy( + logits[-batch_size_s2t:][masked_indices_s2t].contiguous().view(-1, self.output_size), + labels[-batch_size_s2t:][masked_indices_s2t].contiguous().view(-1), ignore_index=-100, reduction='none' + )/p_mask_s2t[masked_indices_s2t] + loss_s2t = torch.sum(loss_s2t/answer_lengths[masked_indices_s2t]) / (logits[-batch_size_s2t:].shape[0]) + + return logits, loss_s2t + + def forward_t2s( + self, + input_ids, + labels, + batch_size_t2s=0, + max_seq_length=128, + p_mask_t2s=None, + answer_lengths=None, + ): + """ + Forward pass for text-to-speech (T2S) diffusion LM training. + + Args: + input_ids: (B, L) Input token IDs (text + [MASK]*len(speech)). + labels: (B, L) Target speech codebook token IDs. + batch_size_t2s: Batch size for t2s task (for multitask batches). + max_seq_length: Prompt(text) źøøģ“ + p_mask_t2s: (B, L) Mask probability per position (optional). + answer_lengths: (B,) 각 row별 target length (optional). + Returns: + logits, loss_t2s + """ + attention_bias = torch.ones(input_ids.shape[0], 1, input_ids.shape[1], input_ids.shape[1], device=input_ids.device) + logits = self(input_ids, attention_bias=attention_bias).logits + self.output_size = logits.shape[-1] + + masked_indices = input_ids == self.config.mask_token_id + masked_indices_t2s = masked_indices[-batch_size_t2s:] + p_mask_t2s = p_mask_t2s.to(masked_indices_t2s.device) + answer_lengths = answer_lengths.to(masked_indices_t2s.device) + + loss_t2s = F.cross_entropy( + logits[-batch_size_t2s:][masked_indices_t2s].contiguous().view(-1, self.output_size), + labels[-batch_size_t2s:][masked_indices_t2s].contiguous().view(-1), + ignore_index=-100, reduction='none' + ) / p_mask_t2s[masked_indices_t2s] + loss_t2s = torch.sum(loss_t2s / answer_lengths[masked_indices_t2s]) / logits[-batch_size_t2s:].shape[0] + + return logits, loss_t2s + + def forward_v2t( + self, + input_ids, + labels, + batch_size_v2t=0, + max_seq_length=128, + p_mask_v2t=None, + answer_lengths=None, + ): + """ + video-to-text (V2T) diffusion LM training. + """ + attention_bias = torch.ones(input_ids.shape[0], 1, input_ids.shape[1], input_ids.shape[1], device=input_ids.device) + logits = self(input_ids, attention_bias=attention_bias).logits + self.output_size = logits.shape[-1] + + masked_indices = input_ids == self.config.mask_token_id + masked_indices_v2t = masked_indices[:batch_size_v2t] + p_mask_v2t = p_mask_v2t.to(masked_indices_v2t.device) + answer_lengths = answer_lengths.to(masked_indices_v2t.device) + + loss_v2t = F.cross_entropy( + logits[:batch_size_v2t][masked_indices_v2t].contiguous().view(-1, self.output_size), + labels[:batch_size_v2t][masked_indices_v2t].contiguous().view(-1), + ignore_index=-100, + reduction='none' + ) / p_mask_v2t[masked_indices_v2t] + loss_v2t = torch.sum(loss_v2t / answer_lengths[masked_indices_v2t]) / (logits[:batch_size_v2t].shape[0]) + return logits, loss_v2t + + def forward_v2t_encoder( + self, + input_ids, + labels, + batch_size_v2t=0, + max_seq_length=128, + p_mask_v2t=None, + answer_lengths=None, + ): + """ + video-to-text (V2T) diffusion LM training. + """ + attention_bias = torch.ones(input_ids.shape[0], 1, input_ids.shape[1], input_ids.shape[1], device=input_ids.device) + input_embeddings = super().model.transformer.wte(input_ids) + + + logits = self(input_ids, attention_bias=attention_bias).logits + self.output_size = logits.shape[-1] + + masked_indices = input_ids == self.config.mask_token_id + masked_indices_v2t = masked_indices[:batch_size_v2t] + p_mask_v2t = p_mask_v2t.to(masked_indices_v2t.device) + answer_lengths = answer_lengths.to(masked_indices_v2t.device) + + loss_v2t = F.cross_entropy( + logits[:batch_size_v2t][masked_indices_v2t].contiguous().view(-1, self.output_size), + labels[:batch_size_v2t][masked_indices_v2t].contiguous().view(-1), + ignore_index=-100, + reduction='none' + ) / p_mask_v2t[masked_indices_v2t] + loss_v2t = torch.sum(loss_v2t / answer_lengths[masked_indices_v2t]) / (logits[:batch_size_v2t].shape[0]) + return logits, loss_v2t + + def forward_v2s( + self, + input_ids, + labels, + batch_size_v2s=0, + max_seq_length=128, + p_mask_v2s=None, + answer_lengths=None, + ): + """ + # video-to-speech (V2S) diffusion LM training. + """ + attention_bias = torch.ones(input_ids.shape[0], 1, input_ids.shape[1], input_ids.shape[1], device=input_ids.device) + logits = self(input_ids, attention_bias=attention_bias).logits + self.output_size = logits.shape[-1] + + masked_indices = input_ids == self.config.mask_token_id + masked_indices_v2s = masked_indices[:batch_size_v2s] + p_mask_v2s = p_mask_v2s.to(masked_indices_v2s.device) + answer_lengths = answer_lengths.to(masked_indices_v2s.device) + + loss_v2s = F.cross_entropy( + logits[:batch_size_v2s][masked_indices_v2s].contiguous().view(-1, self.output_size), + labels[:batch_size_v2s][masked_indices_v2s].contiguous().view(-1), + ignore_index=-100, + reduction='none' + ) / p_mask_v2s[masked_indices_v2s] + loss_v2s = torch.sum(loss_v2s / answer_lengths[masked_indices_v2s]) / (logits[:batch_size_v2s].shape[0]) + return logits, loss_v2s + + + # def forward_i2i(self, input_ids, attention_mask, labels, max_prompt_length): + # """ + # Forward pass for the I2I task. + # """ + # outputs = self( + # input_ids=input_ids, + # attention_mask=attention_mask + # ) + # logits = outputs.logits + + # logits_for_loss = logits[:, max_prompt_length:].contiguous() + # labels_for_loss = labels[:, max_prompt_length:].contiguous() + + # loss = F.cross_entropy( + # logits_for_loss.view(-1, logits_for_loss.size(-1)), + # labels_for_loss.view(-1), + # ignore_index=-100 + # ) + + # return logits, loss + + @torch.no_grad() + def mmu_generate(self, idx=None, input_embeddings=None, max_new_tokens=128, steps=128,block_length=128, temperature=0.0, top_k=None, eot_token=None, cfg_scale=0.0, remasking='low_confidence', mask_id=126336, attention_mask=None): + """ + Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete + the sequence max_new_tokens times, feeding the predictions back into the model each time. + Most likely you'll want to make sure to be in model.eval() mode of operation for this. + """ + + if attention_mask is not None and 0.0 in attention_mask: + attention_bias = (attention_mask[:, :, None] & attention_mask[:, None, :]).bool().unsqueeze(1) + # print(f"attention_bias: {attention_bias}") + else: + attention_bias = None + try: + device = idx.device + except: + device = input_embeddings.device + + result = [] + batch_size = idx.shape[0] + x = torch.full((batch_size, idx.shape[1] + max_new_tokens), mask_id, dtype=torch.long).to(self.device) + x[:, :idx.shape[1]] = idx.clone() + prompt_index = (x != mask_id) + + + assert max_new_tokens % block_length == 0 + num_blocks = max_new_tokens // block_length + + assert steps % num_blocks == 0 + steps = steps // num_blocks + + # print(f"num_blocks: {num_blocks}, steps: {steps}") + # num_transfer_tokens = get_num_transfer_tokens(prompt_index, steps) + for num_block in range(num_blocks): + block_mask_index = (x[:, idx.shape[1] + num_block * block_length: idx.shape[1] + (num_block + 1) * block_length:] == mask_id) + num_transfer_tokens = get_num_transfer_tokens(block_mask_index, steps) + # num_transfer_tokens = get_num_transfer_tokens(prompt_index, steps) + # print(f"num_transfer_tokens: {num_transfer_tokens}, num_transfer_tokens.shape: {num_transfer_tokens.shape}") + for i in range(steps): + mask_index = (x == mask_id) + if cfg_scale > 0.0: + un_x = x.clone() + un_x[prompt_index] = mask_id + x_ = torch.cat([x, un_x], dim=0) + logits = self(x_).logits + logits, un_logits = torch.chunk(logits, 2, dim=0) + logits = un_logits + (cfg_scale + 1) * (logits - un_logits) + else: + logits = self(x, attention_bias=attention_bias).logits + + logits_with_noise = add_gumbel_noise(logits, temperature=temperature) + x0 = torch.argmax(logits_with_noise, dim=-1) # b, l + if remasking == 'low_confidence': + p = F.softmax(logits.to(torch.float64), dim=-1) + x0_p = torch.squeeze( + torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1) # b, l + elif remasking == 'random': + x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device) + else: + raise NotImplementedError(remasking) + + x0_p[:, idx.shape[1] + (num_block + 1) * block_length:] = -np.inf + + x0 = torch.where(mask_index, x0, x) + confidence = torch.where(mask_index, x0_p, -np.inf) + + transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device) + for j in range(confidence.shape[0]): + _, select_index = torch.topk(confidence[j], k=num_transfer_tokens[j, i]) + transfer_index[j, select_index] = True + x[transfer_index] = x0[transfer_index] + + + # logits = logits[:, -1, :] / temperature + # # optionally crop the logits to only the top k options + # if top_k is not None: + # v, _ = torch.topk(logits, min(top_k, logits.size(-1))) + # logits[logits < v[:, [-1]]] = -float('Inf') + # # apply softmax to convert logits to (normalized) probabilities + # probs = F.softmax(logits, dim=-1) + # # sample from the distribution + # idx_next = torch.multinomial(probs, num_samples=1) + # result.append(idx_next[0][0]) + # # append sampled index to the running sequence and continue + # if self.config.w_clip_vit: + # idx_next_embeddings = self.mmada.model.embed_tokens(idx_next) + # input_embeddings = torch.cat([input_embeddings, idx_next_embeddings], dim=1) + # else: + # idx = torch.cat((idx, idx_next), dim=1) + + # if eot_token is not None and idx_next.cpu() == eot_token: + # break + + return x + + + @torch.no_grad() + def s2t_generate(self, idx=None, input_embeddings=None, max_new_tokens=128, steps=128,block_length=128, temperature=0.0, top_k=None, eot_token=None, cfg_scale=0.0, remasking='low_confidence', mask_id=126336, attention_mask=None): + """ + Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete + the sequence max_new_tokens times, feeding the predictions back into the model each time. + Most likely you'll want to make sure to be in model.eval() mode of operation for this. + """ + + if attention_mask is not None and 0.0 in attention_mask: + attention_bias = (attention_mask[:, :, None] & attention_mask[:, None, :]).bool().unsqueeze(1) + # print(f"attention_bias: {attention_bias}") + else: + attention_bias = None + try: + device = idx.device + except: + device = input_embeddings.device + + result = [] + batch_size = idx.shape[0] + x = torch.full((batch_size, idx.shape[1] + max_new_tokens), mask_id, dtype=torch.long).to(self.device) + x[:, :idx.shape[1]] = idx.clone() + prompt_index = (x != mask_id) + + + assert max_new_tokens % block_length == 0 + num_blocks = max_new_tokens // block_length + + assert steps % num_blocks == 0 + steps = steps // num_blocks + + # print(f"num_blocks: {num_blocks}, steps: {steps}") + # num_transfer_tokens = get_num_transfer_tokens(prompt_index, steps) + for num_block in range(num_blocks): + block_mask_index = (x[:, idx.shape[1] + num_block * block_length: idx.shape[1] + (num_block + 1) * block_length:] == mask_id) + num_transfer_tokens = get_num_transfer_tokens(block_mask_index, steps) + # num_transfer_tokens = get_num_transfer_tokens(prompt_index, steps) + # print(f"num_transfer_tokens: {num_transfer_tokens}, num_transfer_tokens.shape: {num_transfer_tokens.shape}") + for i in range(steps): + mask_index = (x == mask_id) + if cfg_scale > 0.0: + un_x = x.clone() + un_x[prompt_index] = mask_id + x_ = torch.cat([x, un_x], dim=0) + logits = self(x_).logits + logits, un_logits = torch.chunk(logits, 2, dim=0) + logits = un_logits + (cfg_scale + 1) * (logits - un_logits) + else: + logits = self(x, attention_bias=attention_bias).logits + + logits_with_noise = add_gumbel_noise(logits, temperature=temperature) + x0 = torch.argmax(logits_with_noise, dim=-1) # b, l + if remasking == 'low_confidence': + p = F.softmax(logits.to(torch.float64), dim=-1) + x0_p = torch.squeeze( + torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1) # b, l + elif remasking == 'random': + x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device) + else: + raise NotImplementedError(remasking) + + x0_p[:, idx.shape[1] + (num_block + 1) * block_length:] = -np.inf + + x0 = torch.where(mask_index, x0, x) + confidence = torch.where(mask_index, x0_p, -np.inf) + + transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device) + for j in range(confidence.shape[0]): + _, select_index = torch.topk(confidence[j], k=num_transfer_tokens[j, i]) + transfer_index[j, select_index] = True + x[transfer_index] = x0[transfer_index] + + + # logits = logits[:, -1, :] / temperature + # # optionally crop the logits to only the top k options + # if top_k is not None: + # v, _ = torch.topk(logits, min(top_k, logits.size(-1))) + # logits[logits < v[:, [-1]]] = -float('Inf') + # # apply softmax to convert logits to (normalized) probabilities + # probs = F.softmax(logits, dim=-1) + # # sample from the distribution + # idx_next = torch.multinomial(probs, num_samples=1) + # result.append(idx_next[0][0]) + # # append sampled index to the running sequence and continue + # if self.config.w_clip_vit: + # idx_next_embeddings = self.mmada.model.embed_tokens(idx_next) + # input_embeddings = torch.cat([input_embeddings, idx_next_embeddings], dim=1) + # else: + # idx = torch.cat((idx, idx_next), dim=1) + + # if eot_token is not None and idx_next.cpu() == eot_token: + # break + + return x + + @torch.no_grad() + def mmu_generate_fast(self, idx=None, input_embeddings=None, max_new_tokens=128, steps=128,block_length=128, temperature=0.0, top_k=None, eot_token=None, cfg_scale=0.0, remasking='low_confidence', mask_id=126336, attention_mask=None): + """ + Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete + the sequence max_new_tokens times, feeding the predictions back into the model each time. + Most likely you'll want to make sure to be in model.eval() mode of operation for this. + """ + + if attention_mask is not None and 0.0 in attention_mask: + attention_bias = (attention_mask[:, :, None] & attention_mask[:, None, :]).bool().unsqueeze(1) + # print(f"attention_bias: {attention_bias}") + else: + attention_bias = None + try: + device = idx.device + except: + device = input_embeddings.device + + result = [] + batch_size = idx.shape[0] + x = torch.full((batch_size, idx.shape[1] + max_new_tokens), mask_id, dtype=torch.long).to(self.device) + x[:, :idx.shape[1]] = idx.clone() + prompt_index = (x != mask_id) + + + assert max_new_tokens % block_length == 0 + num_blocks = max_new_tokens // block_length + + assert steps % num_blocks == 0 + steps = steps // num_blocks + + for num_block in range(num_blocks): + block_mask_index = (x[:, idx.shape[1] + num_block * block_length: idx.shape[1] + (num_block + 1) * block_length:] == mask_id) + num_transfer_tokens = get_num_transfer_tokens(block_mask_index, steps) + for i in range(steps): + mask_index = (x == mask_id) + if cfg_scale > 0.0: + un_x = x.clone() + un_x[prompt_index] = mask_id + x_ = torch.cat([x, un_x], dim=0) + logits = self(x_).logits + logits, un_logits = torch.chunk(logits, 2, dim=0) + logits = un_logits + (cfg_scale + 1) * (logits - un_logits) + else: + logits = self(x, attention_bias=attention_bias).logits + + logits_with_noise = add_gumbel_noise(logits, temperature=temperature) + x0 = torch.argmax(logits_with_noise, dim=-1) # b, l + if remasking == 'low_confidence': + p = F.softmax(logits.to(torch.float64), dim=-1) + x0_p = torch.squeeze( + torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1) # b, l + elif remasking == 'random': + x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device) + else: + raise NotImplementedError(remasking) + + x0_p[:, idx.shape[1] + (num_block + 1) * block_length:] = -np.inf + + x0 = torch.where(mask_index, x0, x) + confidence = torch.where(mask_index, x0_p, -np.inf) + + transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device) + for j in range(confidence.shape[0]): + _, select_index = torch.topk(confidence[j], k=num_transfer_tokens[j, i]) + transfer_index[j, select_index] = True + x[transfer_index] = x0[transfer_index] + if eot_token is not None: + last_token_index_in_current_block = idx.shape[1] + (num_block + 1) * block_length - 1 + if last_token_index_in_current_block < x.shape[1]: + tokens_at_block_end = x[:, last_token_index_in_current_block] + if torch.all(tokens_at_block_end == eot_token): + break + return x + + @torch.no_grad() + def t2i_generate_decoding_stepwise( + self, + input_ids: torch.LongTensor = None, + uncond_input_ids: torch.LongTensor = None, + attention_mask=None, + uncond_attention_mask=None, + temperature=1.0, + timesteps=18, # ideal number of steps is 18 in maskgit paper + guidance_scale=0, + noise_schedule=cosine_schedule, + generator: torch.Generator = None, + config=None, + seq_len=1024, + mask_token_id = 126336, + resolution = 512, + codebook_size = 8192, + vq_model = None, + **kwargs, + ): + """ + Generate 1:1 similar to the original MaskGit repo + https://github.com/google-research/maskgit/blob/main/maskgit/libml/parallel_decode.py#L79 + """ + + # begin with all image token ids masked + # č®”ē®—ęœ‰å¤šå°‘äøŖmask token + mask_count = (input_ids == mask_token_id).sum().item() + num_vq_tokens = seq_len + num_new_special_tokens = 0 + uni_prompting = kwargs.get("uni_prompting", None) + # print(f"config.model.mmada.llm_vocab_size: {config.model.mmada.llm_vocab_size}, {len(uni_prompting.text_tokenizer)}") + input_ids_minus_lm_vocab_size = input_ids[:, -(num_vq_tokens + 1):-1].clone() + input_ids_minus_lm_vocab_size = torch.where(input_ids_minus_lm_vocab_size == mask_token_id, mask_token_id, input_ids_minus_lm_vocab_size - len(uni_prompting.text_tokenizer) - num_new_special_tokens) + + # for classifier-free guidance + if uncond_input_ids is not None: + uncond_prefix = uncond_input_ids[:, :resolution + 1] + + for step in range(timesteps): + if uncond_input_ids is not None and guidance_scale > 0: + uncond_input_ids = torch.cat( + [uncond_prefix, input_ids[:, resolution + 1:]], dim=1) + model_input = torch.cat([input_ids, uncond_input_ids]) + attention_mask = torch.cat([attention_mask, uncond_attention_mask], dim=0) + attention_bias = (attention_mask[:, :, None] & attention_mask[:, None, :]).bool().unsqueeze(1) + logits = self(model_input, attention_bias=attention_bias).logits + # print(f"logits.shape: {logits.shape}") + cond_logits, uncond_logits = torch.chunk(logits, 2, dim=0) + # logits = uncond_logits + guidance_scale * (cond_logits - uncond_logits) + # it seems that muse has a different cfg setting + logits = (1 + guidance_scale) * cond_logits - guidance_scale * uncond_logits + logits = logits[:, -(num_vq_tokens + 1):-1, len(uni_prompting.text_tokenizer) + num_new_special_tokens: len(uni_prompting.text_tokenizer) + num_new_special_tokens + codebook_size] + else: + attention_bias = (attention_mask[:, :, None] & attention_mask[:, None, :]).bool().unsqueeze(1) + logits = self(input_ids, attention_bias=attention_bias).logits + logits = logits[:, -(num_vq_tokens + 1):-1, len(uni_prompting.text_tokenizer) + num_new_special_tokens: len(uni_prompting.text_tokenizer) + num_new_special_tokens + codebook_size] + + # logits: 1, 1024, 8192 + # print(f"logits.shape: {logits.shape}") + probs = logits.softmax(dim=-1) + sampled = probs.reshape(-1, logits.size(-1)) + # print(f"probs: {probs}, probs.shape: {probs.shape}, sampled: {sampled}, sampled.shape: {sampled.shape}") + sampled_ids = torch.multinomial(sampled, 1, generator=generator)[:, 0].view(*logits.shape[:-1]) # 1, 1024 + + unknown_map = input_ids_minus_lm_vocab_size == mask_token_id + # print(f"unknown_map.sum(dim=-1, keepdim=True): {unknown_map.sum(dim=-1, keepdim=True)}") + sampled_ids = torch.where(unknown_map, sampled_ids, input_ids_minus_lm_vocab_size) + # Defines the mask ratio for the next round. The number to mask out is + current_image_vq_indices = sampled_ids.clone() + # print(f"current_image_vq_indices: {current_image_vq_indices}") + current_image_vq_indices = torch.clamp(current_image_vq_indices, 0, 8192 - 1) + current_image = vq_model.decode_code(current_image_vq_indices) + images = torch.clamp((current_image + 1.0) / 2.0, min=0.0, max=1.0) + images *= 255.0 + images = images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8) + pil_images = Image.fromarray(images[0]) + yield pil_images, f"Step {step + 1}/{timesteps}" + # determined by mask_ratio * unknown_number_in_the_beginning. + ratio = 1.0 * (step + 1) / timesteps + mask_ratio = noise_schedule(torch.tensor(ratio)) + # Computes the probabilities of each selected tokens. + selected_probs = torch.gather(probs, -1, sampled_ids.long()[..., None]) + selected_probs = selected_probs.squeeze(-1) + + # Ignores the tokens given in the input by overwriting their confidence. + selected_probs = torch.where(unknown_map, selected_probs, torch.finfo(selected_probs.dtype).max) + # Gets mask lens for each sample in the batch according to the mask ratio. + mask_len = (num_vq_tokens * mask_ratio).floor().unsqueeze(0).to(logits.device) + # Keeps at least one of prediction in this round and also masks out at least + # one and for the next iteration + mask_len = torch.max( + torch.tensor([1], device=logits.device), torch.min(unknown_map.sum(dim=-1, keepdim=True) - 1, mask_len) + ) + # print(f"mask_len: {mask_len}, mask_len.shape: {mask_len.shape}") + # Adds noise for randomness + temperature = temperature * (1.0 - ratio) + masking = mask_by_random_topk(mask_len, selected_probs, temperature, generator=generator) + # Masks tokens with lower confidence. + input_ids[:, -(num_vq_tokens + 1):-1] = torch.where(masking, mask_token_id, + sampled_ids + len(uni_prompting.text_tokenizer) + + num_new_special_tokens) + input_ids_minus_lm_vocab_size = torch.where(masking, mask_token_id, sampled_ids) + + + return sampled_ids + + +AutoConfig.register("omada", OMadaConfig) +AutoModelForCausalLM.register(OMadaConfig, OMadaModelLM) +AutoModel.register(OMadaConfig, OMadaModelLM) diff --git a/MMaDA/models/modeling_utils.py b/MMaDA/models/modeling_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b991e511f473d926a317eb20029096c2322a1a7b --- /dev/null +++ b/MMaDA/models/modeling_utils.py @@ -0,0 +1,1207 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import itertools +import json +import os +import re +from collections import OrderedDict +from functools import partial +from pathlib import Path +from typing import Any, Callable, List, Optional, Tuple, Union + +import safetensors +import torch +from huggingface_hub import create_repo, split_torch_state_dict_into_shards +from huggingface_hub.utils import validate_hf_hub_args +from torch import Tensor, nn + +from diffusers import __version__ +from diffusers.utils import ( + FLAX_WEIGHTS_NAME, + SAFE_WEIGHTS_INDEX_NAME, + WEIGHTS_INDEX_NAME, + _add_variant, + _get_checkpoint_shard_files, + _get_model_file, + deprecate, + is_accelerate_available, + is_torch_version, + logging, +) + +CONFIG_NAME = "config.json" +WEIGHTS_NAME = "pytorch_model.bin" +SAFETENSORS_WEIGHTS_NAME = "pytorch_model.safetensors" +HUGGINGFACE_CO_RESOLVE_ENDPOINT = "https://huggingface.co" + +from diffusers.utils.hub_utils import ( + PushToHubMixin, + load_or_create_model_card, + populate_model_card, +) +from diffusers.models.model_loading_utils import ( + _determine_device_map, + _fetch_index_file, + _load_state_dict_into_model, + load_model_dict_into_meta, + load_state_dict, +) + +from diffusers.configuration_utils import ConfigMixin, register_to_config + +logger = logging.get_logger(__name__) + +_REGEX_SHARD = re.compile(r"(.*?)-\d{5}-of-\d{5}") + + +if is_torch_version(">=", "1.9.0"): + _LOW_CPU_MEM_USAGE_DEFAULT = True +else: + _LOW_CPU_MEM_USAGE_DEFAULT = False + + +if is_accelerate_available(): + import accelerate + + +def get_parameter_device(parameter: torch.nn.Module) -> torch.device: + try: + parameters_and_buffers = itertools.chain(parameter.parameters(), parameter.buffers()) + return next(parameters_and_buffers).device + except StopIteration: + # For torch.nn.DataParallel compatibility in PyTorch 1.5 + + def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]: + tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] + return tuples + + gen = parameter._named_members(get_members_fn=find_tensor_attributes) + first_tuple = next(gen) + return first_tuple[1].device + + +def get_parameter_dtype(parameter: torch.nn.Module) -> torch.dtype: + try: + params = tuple(parameter.parameters()) + if len(params) > 0: + return params[0].dtype + + buffers = tuple(parameter.buffers()) + if len(buffers) > 0: + return buffers[0].dtype + + except StopIteration: + # For torch.nn.DataParallel compatibility in PyTorch 1.5 + + def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]: + tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] + return tuples + + gen = parameter._named_members(get_members_fn=find_tensor_attributes) + first_tuple = next(gen) + return first_tuple[1].dtype + + +class ModelMixin(torch.nn.Module, PushToHubMixin): + r""" + Base class for all models. + + [`ModelMixin`] takes care of storing the model configuration and provides methods for loading, downloading and + saving models. + + - **config_name** ([`str`]) -- Filename to save a model to when calling [`~models.ModelMixin.save_pretrained`]. + """ + + config_name = CONFIG_NAME + _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"] + _supports_gradient_checkpointing = False + _keys_to_ignore_on_load_unexpected = None + _no_split_modules = None + + def __init__(self): + super().__init__() + + def __getattr__(self, name: str) -> Any: + """The only reason we overwrite `getattr` here is to gracefully deprecate accessing + config attributes directly. See https://github.com/huggingface/diffusers/pull/3129 We need to overwrite + __getattr__ here in addition so that we don't trigger `torch.nn.Module`'s __getattr__': + https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module + """ + + is_in_config = "_internal_dict" in self.__dict__ and hasattr(self.__dict__["_internal_dict"], name) + is_attribute = name in self.__dict__ + + if is_in_config and not is_attribute: + deprecation_message = f"Accessing config attribute `{name}` directly via '{type(self).__name__}' object attribute is deprecated. Please access '{name}' over '{type(self).__name__}'s config object instead, e.g. 'unet.config.{name}'." + deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False, stacklevel=3) + return self._internal_dict[name] + + # call PyTorch's https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module + return super().__getattr__(name) + + @property + def is_gradient_checkpointing(self) -> bool: + """ + Whether gradient checkpointing is activated for this model or not. + """ + return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules()) + + def enable_gradient_checkpointing(self) -> None: + """ + Activates gradient checkpointing for the current model (may be referred to as *activation checkpointing* or + *checkpoint activations* in other frameworks). + """ + if not self._supports_gradient_checkpointing: + raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.") + self.apply(partial(self._set_gradient_checkpointing, value=True)) + + def disable_gradient_checkpointing(self) -> None: + """ + Deactivates gradient checkpointing for the current model (may be referred to as *activation checkpointing* or + *checkpoint activations* in other frameworks). + """ + if self._supports_gradient_checkpointing: + self.apply(partial(self._set_gradient_checkpointing, value=False)) + + def set_use_npu_flash_attention(self, valid: bool) -> None: + r""" + Set the switch for the npu flash attention. + """ + + def fn_recursive_set_npu_flash_attention(module: torch.nn.Module): + if hasattr(module, "set_use_npu_flash_attention"): + module.set_use_npu_flash_attention(valid) + + for child in module.children(): + fn_recursive_set_npu_flash_attention(child) + + for module in self.children(): + if isinstance(module, torch.nn.Module): + fn_recursive_set_npu_flash_attention(module) + + def enable_npu_flash_attention(self) -> None: + r""" + Enable npu flash attention from torch_npu + + """ + self.set_use_npu_flash_attention(True) + + def disable_npu_flash_attention(self) -> None: + r""" + disable npu flash attention from torch_npu + + """ + self.set_use_npu_flash_attention(False) + + def set_use_memory_efficient_attention_xformers( + self, valid: bool, attention_op: Optional[Callable] = None + ) -> None: + # Recursively walk through all the children. + # Any children which exposes the set_use_memory_efficient_attention_xformers method + # gets the message + def fn_recursive_set_mem_eff(module: torch.nn.Module): + if hasattr(module, "set_use_memory_efficient_attention_xformers"): + module.set_use_memory_efficient_attention_xformers(valid, attention_op) + + for child in module.children(): + fn_recursive_set_mem_eff(child) + + for module in self.children(): + if isinstance(module, torch.nn.Module): + fn_recursive_set_mem_eff(module) + + def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None) -> None: + r""" + Enable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/). + + When this option is enabled, you should observe lower GPU memory usage and a potential speed up during + inference. Speed up during training is not guaranteed. + + + + āš ļø When memory efficient attention and sliced attention are both enabled, memory efficient attention takes + precedent. + + + + Parameters: + attention_op (`Callable`, *optional*): + Override the default `None` operator for use as `op` argument to the + [`memory_efficient_attention()`](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.memory_efficient_attention) + function of xFormers. + + Examples: + + ```py + >>> import torch + >>> from diffusers import UNet2DConditionModel + >>> from xformers.ops import MemoryEfficientAttentionFlashAttentionOp + + >>> model = UNet2DConditionModel.from_pretrained( + ... "stabilityai/stable-diffusion-2-1", subfolder="unet", torch_dtype=torch.float16 + ... ) + >>> model = model.to("cuda") + >>> model.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp) + ``` + """ + self.set_use_memory_efficient_attention_xformers(True, attention_op) + + def disable_xformers_memory_efficient_attention(self) -> None: + r""" + Disable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/). + """ + self.set_use_memory_efficient_attention_xformers(False) + + def save_pretrained( + self, + save_directory: Union[str, os.PathLike], + is_main_process: bool = True, + save_function: Optional[Callable] = None, + safe_serialization: bool = True, + variant: Optional[str] = None, + max_shard_size: Union[int, str] = "10GB", + push_to_hub: bool = False, + **kwargs, + ): + """ + Save a model and its configuration file to a directory so that it can be reloaded using the + [`~models.ModelMixin.from_pretrained`] class method. + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to save a model and its configuration file to. Will be created if it doesn't exist. + is_main_process (`bool`, *optional*, defaults to `True`): + Whether the process calling this is the main process or not. Useful during distributed training and you + need to call this function on all processes. In this case, set `is_main_process=True` only on the main + process to avoid race conditions. + save_function (`Callable`): + The function to use to save the state dictionary. Useful during distributed training when you need to + replace `torch.save` with another method. Can be configured with the environment variable + `DIFFUSERS_SAVE_MODE`. + safe_serialization (`bool`, *optional*, defaults to `True`): + Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. + variant (`str`, *optional*): + If specified, weights are saved in the format `pytorch_model..bin`. + max_shard_size (`int` or `str`, defaults to `"10GB"`): + The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size + lower than this size. If expressed as a string, needs to be digits followed by a unit (like `"5GB"`). + If expressed as an integer, the unit is bytes. Note that this limit will be decreased after a certain + period of time (starting from Oct 2024) to allow users to upgrade to the latest version of `diffusers`. + This is to establish a common default size for this argument across different libraries in the Hugging + Face ecosystem (`transformers`, and `accelerate`, for example). + push_to_hub (`bool`, *optional*, defaults to `False`): + Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the + repository you want to push to with `repo_id` (will default to the name of `save_directory` in your + namespace). + kwargs (`Dict[str, Any]`, *optional*): + Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. + """ + if os.path.isfile(save_directory): + logger.error(f"Provided path ({save_directory}) should be a directory, not a file") + return + + weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME + weights_name = _add_variant(weights_name, variant) + weight_name_split = weights_name.split(".") + if len(weight_name_split) in [2, 3]: + weights_name_pattern = weight_name_split[0] + "{suffix}." + ".".join(weight_name_split[1:]) + else: + raise ValueError(f"Invalid {weights_name} provided.") + + os.makedirs(save_directory, exist_ok=True) + + if push_to_hub: + commit_message = kwargs.pop("commit_message", None) + private = kwargs.pop("private", False) + create_pr = kwargs.pop("create_pr", False) + token = kwargs.pop("token", None) + repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1]) + repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id + + # Only save the model itself if we are using distributed training + model_to_save = self + + # Attach architecture to the config + # Save the config + if is_main_process: + model_to_save.save_config(save_directory) + + # Save the model + state_dict = model_to_save.state_dict() + + # Save the model + state_dict_split = split_torch_state_dict_into_shards( + state_dict, max_shard_size=max_shard_size, filename_pattern=weights_name_pattern + ) + + # Clean the folder from a previous save + if is_main_process: + for filename in os.listdir(save_directory): + if filename in state_dict_split.filename_to_tensors.keys(): + continue + full_filename = os.path.join(save_directory, filename) + if not os.path.isfile(full_filename): + continue + weights_without_ext = weights_name_pattern.replace(".bin", "").replace(".safetensors", "") + weights_without_ext = weights_without_ext.replace("{suffix}", "") + filename_without_ext = filename.replace(".bin", "").replace(".safetensors", "") + # make sure that file to be deleted matches format of sharded file, e.g. pytorch_model-00001-of-00005 + if ( + filename.startswith(weights_without_ext) + and _REGEX_SHARD.fullmatch(filename_without_ext) is not None + ): + os.remove(full_filename) + + for filename, tensors in state_dict_split.filename_to_tensors.items(): + shard = {tensor: state_dict[tensor] for tensor in tensors} + filepath = os.path.join(save_directory, filename) + if safe_serialization: + # At some point we will need to deal better with save_function (used for TPU and other distributed + # joyfulness), but for now this enough. + safetensors.torch.save_file(shard, filepath, metadata={"format": "pt"}) + else: + torch.save(shard, filepath) + + if state_dict_split.is_sharded: + index = { + "metadata": state_dict_split.metadata, + "weight_map": state_dict_split.tensor_to_filename, + } + save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME + save_index_file = os.path.join(save_directory, _add_variant(save_index_file, variant)) + # Save the index as well + with open(save_index_file, "w", encoding="utf-8") as f: + content = json.dumps(index, indent=2, sort_keys=True) + "\n" + f.write(content) + logger.info( + f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be " + f"split in {len(state_dict_split.filename_to_tensors)} checkpoint shards. You can find where each parameters has been saved in the " + f"index located at {save_index_file}." + ) + else: + path_to_weights = os.path.join(save_directory, weights_name) + logger.info(f"Model weights saved in {path_to_weights}") + + if push_to_hub: + # Create a new empty model card and eventually tag it + model_card = load_or_create_model_card(repo_id, token=token) + model_card = populate_model_card(model_card) + model_card.save(Path(save_directory, "README.md").as_posix()) + + self._upload_folder( + save_directory, + repo_id, + token=token, + commit_message=commit_message, + create_pr=create_pr, + ) + + @classmethod + @validate_hf_hub_args + def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): + r""" + Instantiate a pretrained PyTorch model from a pretrained model configuration. + + The model is set in evaluation mode - `model.eval()` - by default, and dropout modules are deactivated. To + train the model, set it back in training mode with `model.train()`. + + Parameters: + pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): + Can be either: + + - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on + the Hub. + - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved + with [`~ModelMixin.save_pretrained`]. + + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + torch_dtype (`str` or `torch.dtype`, *optional*): + Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the + dtype is automatically derived from the model's weights. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + output_loading_info (`bool`, *optional*, defaults to `False`): + Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. + local_files_only(`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to `True`, the model + won't be downloaded from the Hub. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + from_flax (`bool`, *optional*, defaults to `False`): + Load the model weights from a Flax checkpoint save file. + subfolder (`str`, *optional*, defaults to `""`): + The subfolder location of a model file within a larger model repository on the Hub or locally. + mirror (`str`, *optional*): + Mirror source to resolve accessibility issues if you're downloading a model in China. We do not + guarantee the timeliness or safety of the source, and you should refer to the mirror site for more + information. + device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*): + A map that specifies where each submodule should go. It doesn't need to be defined for each + parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the + same device. Defaults to `None`, meaning that the model will be loaded on CPU. + + Set `device_map="auto"` to have šŸ¤— Accelerate automatically compute the most optimized `device_map`. For + more information about each option see [designing a device + map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map). + max_memory (`Dict`, *optional*): + A dictionary device identifier for the maximum memory. Will default to the maximum memory available for + each GPU and the available CPU RAM if unset. + offload_folder (`str` or `os.PathLike`, *optional*): + The path to offload weights if `device_map` contains the value `"disk"`. + offload_state_dict (`bool`, *optional*): + If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if + the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True` + when there is some disk offload. + low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`): + Speed up model loading only loading the pretrained weights and not initializing the weights. This also + tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model. + Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this + argument to `True` will raise an error. + variant (`str`, *optional*): + Load weights from a specified `variant` filename such as `"fp16"` or `"ema"`. This is ignored when + loading `from_flax`. + use_safetensors (`bool`, *optional*, defaults to `None`): + If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the + `safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors` + weights. If set to `False`, `safetensors` weights are not loaded. + + + + To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with + `huggingface-cli login`. You can also activate the special + ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a + firewalled environment. + + + + Example: + + ```py + from diffusers import UNet2DConditionModel + + unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet") + ``` + + If you get the error message below, you need to finetune the weights for your downstream task: + + ```bash + Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match: + - conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated + You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference. + ``` + """ + cache_dir = kwargs.pop("cache_dir", None) + ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False) + force_download = kwargs.pop("force_download", False) + from_flax = kwargs.pop("from_flax", False) + proxies = kwargs.pop("proxies", None) + output_loading_info = kwargs.pop("output_loading_info", False) + local_files_only = kwargs.pop("local_files_only", None) + token = kwargs.pop("token", None) + revision = kwargs.pop("revision", None) + torch_dtype = kwargs.pop("torch_dtype", None) + subfolder = kwargs.pop("subfolder", None) + device_map = kwargs.pop("device_map", None) + max_memory = kwargs.pop("max_memory", None) + offload_folder = kwargs.pop("offload_folder", None) + offload_state_dict = kwargs.pop("offload_state_dict", False) + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT) + variant = kwargs.pop("variant", None) + use_safetensors = kwargs.pop("use_safetensors", None) + + allow_pickle = False + if use_safetensors is None: + use_safetensors = True + allow_pickle = True + + if low_cpu_mem_usage and not is_accelerate_available(): + low_cpu_mem_usage = False + logger.warning( + "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the" + " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install" + " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip" + " install accelerate\n```\n." + ) + + if device_map is not None and not is_accelerate_available(): + raise NotImplementedError( + "Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set" + " `device_map=None`. You can install accelerate with `pip install accelerate`." + ) + + # Check if we can handle device_map and dispatching the weights + if device_map is not None and not is_torch_version(">=", "1.9.0"): + raise NotImplementedError( + "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set" + " `device_map=None`." + ) + + if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"): + raise NotImplementedError( + "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set" + " `low_cpu_mem_usage=False`." + ) + + if low_cpu_mem_usage is False and device_map is not None: + raise ValueError( + f"You cannot set `low_cpu_mem_usage` to `False` while using device_map={device_map} for loading and" + " dispatching. Please make sure to set `low_cpu_mem_usage=True`." + ) + + # change device_map into a map if we passed an int, a str or a torch.device + if isinstance(device_map, torch.device): + device_map = {"": device_map} + elif isinstance(device_map, str) and device_map not in ["auto", "balanced", "balanced_low_0", "sequential"]: + try: + device_map = {"": torch.device(device_map)} + except RuntimeError: + raise ValueError( + "When passing device_map as a string, the value needs to be a device name (e.g. cpu, cuda:0) or " + f"'auto', 'balanced', 'balanced_low_0', 'sequential' but found {device_map}." + ) + elif isinstance(device_map, int): + if device_map < 0: + raise ValueError( + "You can't pass device_map as a negative int. If you want to put the model on the cpu, pass device_map = 'cpu' " + ) + else: + device_map = {"": device_map} + + if device_map is not None: + if low_cpu_mem_usage is None: + low_cpu_mem_usage = True + elif not low_cpu_mem_usage: + raise ValueError("Passing along a `device_map` requires `low_cpu_mem_usage=True`") + + if low_cpu_mem_usage: + if device_map is not None and not is_torch_version(">=", "1.10"): + # The max memory utils require PyTorch >= 1.10 to have torch.cuda.mem_get_info. + raise ValueError("`low_cpu_mem_usage` and `device_map` require PyTorch >= 1.10.") + + # Load config if we don't provide a configuration + config_path = pretrained_model_name_or_path + + user_agent = { + "diffusers": __version__, + "file_type": "model", + "framework": "pytorch", + } + + # load config + config, unused_kwargs, commit_hash = cls.load_config( + config_path, + cache_dir=cache_dir, + return_unused_kwargs=True, + return_commit_hash=True, + force_download=force_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + **kwargs, + ) + + # Determine if we're loading from a directory of sharded checkpoints. + is_sharded = False + index_file = None + is_local = os.path.isdir(pretrained_model_name_or_path) + index_file = _fetch_index_file( + is_local=is_local, + pretrained_model_name_or_path=pretrained_model_name_or_path, + subfolder=subfolder or "", + use_safetensors=use_safetensors, + cache_dir=cache_dir, + variant=variant, + force_download=force_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + user_agent=user_agent, + commit_hash=commit_hash, + ) + if index_file is not None and index_file.is_file(): + is_sharded = True + + if is_sharded and from_flax: + raise ValueError("Loading of sharded checkpoints is not supported when `from_flax=True`.") + + # load model + model_file = None + if from_flax: + model_file = _get_model_file( + pretrained_model_name_or_path, + weights_name=FLAX_WEIGHTS_NAME, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + commit_hash=commit_hash, + ) + model = cls.from_config(config, **unused_kwargs) + + # Convert the weights + from .modeling_pytorch_flax_utils import load_flax_checkpoint_in_pytorch_model + + model = load_flax_checkpoint_in_pytorch_model(model, model_file) + else: + if is_sharded: + sharded_ckpt_cached_folder, sharded_metadata = _get_checkpoint_shard_files( + pretrained_model_name_or_path, + index_file, + cache_dir=cache_dir, + proxies=proxies, + local_files_only=local_files_only, + token=token, + user_agent=user_agent, + revision=revision, + subfolder=subfolder or "", + ) + + elif use_safetensors and not is_sharded: + try: + model_file = _get_model_file( + pretrained_model_name_or_path, + weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant), + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + commit_hash=commit_hash, + ) + + except IOError as e: + logger.error(f"An error occurred while trying to fetch {pretrained_model_name_or_path}: {e}") + if not allow_pickle: + raise + logger.warning( + "Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead." + ) + + if model_file is None and not is_sharded: + model_file = _get_model_file( + pretrained_model_name_or_path, + weights_name=_add_variant(WEIGHTS_NAME, variant), + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + commit_hash=commit_hash, + ) + + if low_cpu_mem_usage: + # Instantiate model with empty weights + with accelerate.init_empty_weights(): + model = cls.from_config(config, **unused_kwargs) + + # if device_map is None, load the state dict and move the params from meta device to the cpu + if device_map is None and not is_sharded: + param_device = "cpu" + state_dict = load_state_dict(model_file, variant=variant) + model._convert_deprecated_attention_blocks(state_dict) + # move the params from meta device to cpu + missing_keys = set(model.state_dict().keys()) - set(state_dict.keys()) + if len(missing_keys) > 0: + raise ValueError( + f"Cannot load {cls} from {pretrained_model_name_or_path} because the following keys are" + f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass" + " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize" + " those weights or else make sure your checkpoint file is correct." + ) + + unexpected_keys = load_model_dict_into_meta( + model, + state_dict, + device=param_device, + dtype=torch_dtype, + model_name_or_path=pretrained_model_name_or_path, + ) + + if cls._keys_to_ignore_on_load_unexpected is not None: + for pat in cls._keys_to_ignore_on_load_unexpected: + unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None] + + if len(unexpected_keys) > 0: + logger.warning( + f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}" + ) + + else: # else let accelerate handle loading and dispatching. + # Load weights and dispatch according to the device_map + # by default the device_map is None and the weights are loaded on the CPU + force_hook = True + device_map = _determine_device_map(model, device_map, max_memory, torch_dtype) + if device_map is None and is_sharded: + # we load the parameters on the cpu + device_map = {"": "cpu"} + force_hook = False + try: + accelerate.load_checkpoint_and_dispatch( + model, + model_file if not is_sharded else index_file, + device_map, + max_memory=max_memory, + offload_folder=offload_folder, + offload_state_dict=offload_state_dict, + dtype=torch_dtype, + force_hooks=force_hook, + strict=True, + ) + except AttributeError as e: + # When using accelerate loading, we do not have the ability to load the state + # dict and rename the weight names manually. Additionally, accelerate skips + # torch loading conventions and directly writes into `module.{_buffers, _parameters}` + # (which look like they should be private variables?), so we can't use the standard hooks + # to rename parameters on load. We need to mimic the original weight names so the correct + # attributes are available. After we have loaded the weights, we convert the deprecated + # names to the new non-deprecated names. Then we _greatly encourage_ the user to convert + # the weights so we don't have to do this again. + + if "'Attention' object has no attribute" in str(e): + logger.warning( + f"Taking `{str(e)}` while using `accelerate.load_checkpoint_and_dispatch` to mean {pretrained_model_name_or_path}" + " was saved with deprecated attention block weight names. We will load it with the deprecated attention block" + " names and convert them on the fly to the new attention block format. Please re-save the model after this conversion," + " so we don't have to do the on the fly renaming in the future. If the model is from a hub checkpoint," + " please also re-upload it or open a PR on the original repository." + ) + model._temp_convert_self_to_deprecated_attention_blocks() + accelerate.load_checkpoint_and_dispatch( + model, + model_file if not is_sharded else index_file, + device_map, + max_memory=max_memory, + offload_folder=offload_folder, + offload_state_dict=offload_state_dict, + dtype=torch_dtype, + force_hooks=force_hook, + strict=True, + ) + model._undo_temp_convert_self_to_deprecated_attention_blocks() + else: + raise e + + loading_info = { + "missing_keys": [], + "unexpected_keys": [], + "mismatched_keys": [], + "error_msgs": [], + } + else: + model = cls.from_config(config, **unused_kwargs) + + state_dict = load_state_dict(model_file, variant=variant) + model._convert_deprecated_attention_blocks(state_dict) + + model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model( + model, + state_dict, + model_file, + pretrained_model_name_or_path, + ignore_mismatched_sizes=ignore_mismatched_sizes, + ) + + loading_info = { + "missing_keys": missing_keys, + "unexpected_keys": unexpected_keys, + "mismatched_keys": mismatched_keys, + "error_msgs": error_msgs, + } + + if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype): + raise ValueError( + f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}." + ) + elif torch_dtype is not None: + model = model.to(torch_dtype) + + model.register_to_config(_name_or_path=pretrained_model_name_or_path) + + # Set model in evaluation mode to deactivate DropOut modules by default + model.eval() + if output_loading_info: + return model, loading_info + + return model + + @classmethod + def _load_pretrained_model( + cls, + model, + state_dict: OrderedDict, + resolved_archive_file, + pretrained_model_name_or_path: Union[str, os.PathLike], + ignore_mismatched_sizes: bool = False, + ): + # Retrieve missing & unexpected_keys + model_state_dict = model.state_dict() + loaded_keys = list(state_dict.keys()) + + expected_keys = list(model_state_dict.keys()) + + original_loaded_keys = loaded_keys + + missing_keys = list(set(expected_keys) - set(loaded_keys)) + unexpected_keys = list(set(loaded_keys) - set(expected_keys)) + + # Make sure we are able to load base models as well as derived models (with heads) + model_to_load = model + + def _find_mismatched_keys( + state_dict, + model_state_dict, + loaded_keys, + ignore_mismatched_sizes, + ): + mismatched_keys = [] + if ignore_mismatched_sizes: + for checkpoint_key in loaded_keys: + model_key = checkpoint_key + + if ( + model_key in model_state_dict + and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape + ): + mismatched_keys.append( + (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape) + ) + del state_dict[checkpoint_key] + return mismatched_keys + + if state_dict is not None: + # Whole checkpoint + mismatched_keys = _find_mismatched_keys( + state_dict, + model_state_dict, + original_loaded_keys, + ignore_mismatched_sizes, + ) + error_msgs = _load_state_dict_into_model(model_to_load, state_dict) + + if len(error_msgs) > 0: + error_msg = "\n\t".join(error_msgs) + if "size mismatch" in error_msg: + error_msg += ( + "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method." + ) + raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}") + + if len(unexpected_keys) > 0: + logger.warning( + f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when" + f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are" + f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task" + " or with another architecture (e.g. initializing a BertForSequenceClassification model from a" + " BertForPreTraining model).\n- This IS NOT expected if you are initializing" + f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly" + " identical (initializing a BertForSequenceClassification model from a" + " BertForSequenceClassification model)." + ) + else: + logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n") + if len(missing_keys) > 0: + logger.warning( + f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" + f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably" + " TRAIN this model on a down-stream task to be able to use it for predictions and inference." + ) + elif len(mismatched_keys) == 0: + logger.info( + f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at" + f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the" + f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions" + " without further training." + ) + if len(mismatched_keys) > 0: + mismatched_warning = "\n".join( + [ + f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated" + for key, shape1, shape2 in mismatched_keys + ] + ) + logger.warning( + f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" + f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not" + f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be" + " able to use it for predictions and inference." + ) + + return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs + + @classmethod + def _get_signature_keys(cls, obj): + parameters = inspect.signature(obj.__init__).parameters + required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty} + optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty}) + expected_modules = set(required_parameters.keys()) - {"self"} + + return expected_modules, optional_parameters + + # Adapted from `transformers` modeling_utils.py + def _get_no_split_modules(self, device_map: str): + """ + Get the modules of the model that should not be spit when using device_map. We iterate through the modules to + get the underlying `_no_split_modules`. + + Args: + device_map (`str`): + The device map value. Options are ["auto", "balanced", "balanced_low_0", "sequential"] + + Returns: + `List[str]`: List of modules that should not be split + """ + _no_split_modules = set() + modules_to_check = [self] + while len(modules_to_check) > 0: + module = modules_to_check.pop(-1) + # if the module does not appear in _no_split_modules, we also check the children + if module.__class__.__name__ not in _no_split_modules: + if isinstance(module, ModelMixin): + if module._no_split_modules is None: + raise ValueError( + f"{module.__class__.__name__} does not support `device_map='{device_map}'`. To implement support, the model " + "class needs to implement the `_no_split_modules` attribute." + ) + else: + _no_split_modules = _no_split_modules | set(module._no_split_modules) + modules_to_check += list(module.children()) + return list(_no_split_modules) + + @property + def device(self) -> torch.device: + """ + `torch.device`: The device on which the module is (assuming that all the module parameters are on the same + device). + """ + return get_parameter_device(self) + + @property + def dtype(self) -> torch.dtype: + """ + `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype). + """ + return get_parameter_dtype(self) + + def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int: + """ + Get number of (trainable or non-embedding) parameters in the module. + + Args: + only_trainable (`bool`, *optional*, defaults to `False`): + Whether or not to return only the number of trainable parameters. + exclude_embeddings (`bool`, *optional*, defaults to `False`): + Whether or not to return only the number of non-embedding parameters. + + Returns: + `int`: The number of parameters. + + Example: + + ```py + from diffusers import UNet2DConditionModel + + model_id = "runwayml/stable-diffusion-v1-5" + unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet") + unet.num_parameters(only_trainable=True) + 859520964 + ``` + """ + + if exclude_embeddings: + embedding_param_names = [ + f"{name}.weight" + for name, module_type in self.named_modules() + if isinstance(module_type, torch.nn.Embedding) + ] + non_embedding_parameters = [ + parameter for name, parameter in self.named_parameters() if name not in embedding_param_names + ] + return sum(p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable) + else: + return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable) + + def _convert_deprecated_attention_blocks(self, state_dict: OrderedDict) -> None: + deprecated_attention_block_paths = [] + + def recursive_find_attn_block(name, module): + if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block: + deprecated_attention_block_paths.append(name) + + for sub_name, sub_module in module.named_children(): + sub_name = sub_name if name == "" else f"{name}.{sub_name}" + recursive_find_attn_block(sub_name, sub_module) + + recursive_find_attn_block("", self) + + # NOTE: we have to check if the deprecated parameters are in the state dict + # because it is possible we are loading from a state dict that was already + # converted + + for path in deprecated_attention_block_paths: + # group_norm path stays the same + + # query -> to_q + if f"{path}.query.weight" in state_dict: + state_dict[f"{path}.to_q.weight"] = state_dict.pop(f"{path}.query.weight") + if f"{path}.query.bias" in state_dict: + state_dict[f"{path}.to_q.bias"] = state_dict.pop(f"{path}.query.bias") + + # key -> to_k + if f"{path}.key.weight" in state_dict: + state_dict[f"{path}.to_k.weight"] = state_dict.pop(f"{path}.key.weight") + if f"{path}.key.bias" in state_dict: + state_dict[f"{path}.to_k.bias"] = state_dict.pop(f"{path}.key.bias") + + # value -> to_v + if f"{path}.value.weight" in state_dict: + state_dict[f"{path}.to_v.weight"] = state_dict.pop(f"{path}.value.weight") + if f"{path}.value.bias" in state_dict: + state_dict[f"{path}.to_v.bias"] = state_dict.pop(f"{path}.value.bias") + + # proj_attn -> to_out.0 + if f"{path}.proj_attn.weight" in state_dict: + state_dict[f"{path}.to_out.0.weight"] = state_dict.pop(f"{path}.proj_attn.weight") + if f"{path}.proj_attn.bias" in state_dict: + state_dict[f"{path}.to_out.0.bias"] = state_dict.pop(f"{path}.proj_attn.bias") + + def _temp_convert_self_to_deprecated_attention_blocks(self) -> None: + deprecated_attention_block_modules = [] + + def recursive_find_attn_block(module): + if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block: + deprecated_attention_block_modules.append(module) + + for sub_module in module.children(): + recursive_find_attn_block(sub_module) + + recursive_find_attn_block(self) + + for module in deprecated_attention_block_modules: + module.query = module.to_q + module.key = module.to_k + module.value = module.to_v + module.proj_attn = module.to_out[0] + + # We don't _have_ to delete the old attributes, but it's helpful to ensure + # that _all_ the weights are loaded into the new attributes and we're not + # making an incorrect assumption that this model should be converted when + # it really shouldn't be. + del module.to_q + del module.to_k + del module.to_v + del module.to_out + + def _undo_temp_convert_self_to_deprecated_attention_blocks(self) -> None: + deprecated_attention_block_modules = [] + + def recursive_find_attn_block(module) -> None: + if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block: + deprecated_attention_block_modules.append(module) + + for sub_module in module.children(): + recursive_find_attn_block(sub_module) + + recursive_find_attn_block(self) + + for module in deprecated_attention_block_modules: + module.to_q = module.query + module.to_k = module.key + module.to_v = module.value + module.to_out = nn.ModuleList([module.proj_attn, nn.Dropout(module.dropout)]) + + del module.query + del module.key + del module.value + del module.proj_attn + + +class LegacyModelMixin(ModelMixin): + r""" + A subclass of `ModelMixin` to resolve class mapping from legacy classes (like `Transformer2DModel`) to more + pipeline-specific classes (like `DiTTransformer2DModel`). + """ + + @classmethod + @validate_hf_hub_args + def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): + # To prevent dependency import problem. + from diffusers.models.model_loading_utils import _fetch_remapped_cls_from_config + + # Create a copy of the kwargs so that we don't mess with the keyword arguments in the downstream calls. + kwargs_copy = kwargs.copy() + + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", None) + token = kwargs.pop("token", None) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", None) + + # Load config if we don't provide a configuration + config_path = pretrained_model_name_or_path + + user_agent = { + "diffusers": __version__, + "file_type": "model", + "framework": "pytorch", + } + + # load config + config, _, _ = cls.load_config( + config_path, + cache_dir=cache_dir, + return_unused_kwargs=True, + return_commit_hash=True, + force_download=force_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + **kwargs, + ) + # resolve remapping + remapped_class = _fetch_remapped_cls_from_config(config, cls) + + return remapped_class.from_pretrained(pretrained_model_name_or_path, **kwargs_copy) \ No newline at end of file diff --git a/MMaDA/models/modeling_video_encoder.py b/MMaDA/models/modeling_video_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..0f6214b984ae1510a287ed476bac9c44fa6571f4 --- /dev/null +++ b/MMaDA/models/modeling_video_encoder.py @@ -0,0 +1,26 @@ +import torch +import torch.nn as nn + +class VideoEncoder(nn.Module): + def __init__(self, dim, num_heads=8, dropout=0.1): + super(VideoEncoder, self).__init__() + self.attention = nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads, dropout=dropout, batch_first=True) + self.norm1 = nn.LayerNorm(dim) + self.mlp = nn.Sequential( + nn.Linear(dim, dim * 4), + nn.GELU(), + nn.Linear(dim * 4, dim), + nn.Dropout(dropout) + ) + self.norm2 = nn.LayerNorm(dim) + + def forward(self, x): + # x shape: (batch_size, seq_len, dim) + residual = x + attn_output, _ = self.attention(x, x, x) + x = self.norm1(attn_output + residual) + + residual = x + x = self.mlp(x) + x = self.norm2(x + residual) + return x # shape: (batch_size, seq_len, dim) \ No newline at end of file diff --git a/MMaDA/models/sampling.py b/MMaDA/models/sampling.py new file mode 100644 index 0000000000000000000000000000000000000000..ad9d66c0857657d2d024629d3923c9885c44b0c8 --- /dev/null +++ b/MMaDA/models/sampling.py @@ -0,0 +1,118 @@ +# Adapted from https://github.com/lucidrains/muse-maskgit-pytorch + +import math +from functools import partial + +import torch +import torch.nn.functional as F + + +def log(t, eps=1e-20): + return torch.log(t.clamp(min=eps)) + + +def gumbel_noise(t, generator=None): + noise = torch.zeros_like(t).uniform_(0, 1, generator=generator) + return -log(-log(noise)) + + +def gumbel_sample(t, temperature=1.0, dim=-1, generator=None): + return ((t / max(temperature, 1e-10)) + gumbel_noise(t, generator=generator)).argmax(dim=dim) + + +def top_k(logits, thres=0.9): + k = math.ceil((1 - thres) * logits.shape[-1]) + val, ind = logits.topk(k, dim=-1) + probs = torch.full_like(logits, float("-inf")) + probs.scatter_(2, ind, val) + return probs + + +def mask_by_random_topk(mask_len, probs, temperature=1.0, generator=None): + confidence = log(probs) + temperature * gumbel_noise(probs, generator=generator) + sorted_confidence = torch.sort(confidence, dim=-1).values + cut_off = torch.gather(sorted_confidence, 1, mask_len.long()) + masking = confidence < cut_off + return masking + + +def cosine_schedule(t): + return torch.cos(t * math.pi * 0.5) + + +def linear_schedule(t): + mask_ratio = 1 - t + mask_ratio = mask_ratio.clamp(min=1e-6, max=1.0) + return mask_ratio + + +def pow(t, method): + exponent = float(method.replace("pow", "")) + mask_ratio = 1.0 - t**exponent + mask_ratio = mask_ratio.clamp(min=1e-6, max=1.0) + return mask_ratio + + +def sigmoid_schedule(t, start=-3, end=3, tau=1.0, clip_min=1e-6): + for item in [t, start, end, tau]: + item = torch.tensor(item) if not torch.is_tensor(item) else item + + # A gamma function based on sigmoid function. + v_start = torch.sigmoid(torch.tensor(start / tau)) + v_end = torch.sigmoid(torch.tensor(end / tau)) + output = torch.sigmoid((t * (end - start) + start) / tau) + output = (v_end - output) / (v_end - v_start) + return torch.clip(output, clip_min, 1.0) + + +def get_mask_schedule(method, **schedule_kwargs): + if method == "cosine": + return cosine_schedule + elif method == "linear": + return linear_schedule + elif "pow" in method: + return partial(pow, method=method) + elif method == "sigmoid": + return partial(sigmoid_schedule, **schedule_kwargs) + else: + raise ValueError("Unknown schedule method: {}".format(method)) + +def top_k_top_p_filtering( + logits: torch.Tensor, + top_k: int = 0, + top_p: float = 1.0, + filter_value: float = -float("Inf"), + min_tokens_to_keep: int = 1, +) -> torch.Tensor: + """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering + Args: + logits: logits distribution shape (batch size, vocabulary size) + if top_k > 0: keep only top k tokens with highest probability (top-k filtering). + if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering). + Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751) + Make sure we keep at least min_tokens_to_keep per batch example in the output + From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317 + """ + if top_k > 0: + top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1)) # Safety check + # Remove all tokens with a probability less than the last token of the top-k + indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] + logits[indices_to_remove] = filter_value + + if top_p < 1.0: + sorted_logits, sorted_indices = torch.sort(logits, descending=True) + cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) + + # Remove tokens with cumulative probability above the threshold (token with 0 are kept) + sorted_indices_to_remove = cumulative_probs > top_p + if min_tokens_to_keep > 1: + # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below) + sorted_indices_to_remove[..., :min_tokens_to_keep] = 0 + # Shift the indices to the right to keep also the first token above the threshold + sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() + sorted_indices_to_remove[..., 0] = 0 + + # scatter sorted tensors to original indexing + indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) + logits[indices_to_remove] = filter_value + return logits diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/hmodels/modules/activations.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/hmodels/modules/activations.py new file mode 100644 index 0000000000000000000000000000000000000000..1214173c0f070eced400a397a59f74c227a581a9 --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/hmodels/modules/activations.py @@ -0,0 +1,14 @@ +import torch +import torch.nn as nn + + +class Swish(nn.Module): + def forward(self, x): + return x * torch.sigmoid(x) + + +def SiLU(): + if hasattr(torch.nn, 'SiLU'): + return torch.nn.SiLU() + else: + return Swish() diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/hmodels/modules/autoencoder.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/hmodels/modules/autoencoder.py new file mode 100644 index 0000000000000000000000000000000000000000..74301ee2c214e8eb8358f22ea064e8fd5f0e6426 --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/hmodels/modules/autoencoder.py @@ -0,0 +1,40 @@ +import torch + +from nemo.collections.asr.parts.convolution_layers import ConvBlock + + +class ConvAutoencoder(torch.nn.Module): + + def __init__(self, encoder, decoder, loss_type='l2'): + super().__init__() + + self.encoder = ConvBlock(**encoder) + self.decoder = ConvBlock(**decoder) + assert loss_type == 'l2' + self.loss_type = loss_type + + def forward(self, inputs, inputs_len): + # inputs: [B, C, T] + encoded = self.encode(inputs, inputs_len) + reconstructed = self.decode(encoded, inputs_len) + + loss_mask = seq_mask(inputs_len, inputs.shape[2]) + mse_loss = torch.nn.functional.mse_loss(inputs.transpose(1, 2)[loss_mask], + reconstructed.transpose(1, 2)[loss_mask]) + + return mse_loss, reconstructed + + def encode(self, inputs, inputs_len): + encoded, _ = self.encoder(inputs, inputs_len) + return encoded + + def decode(self, encoded, inputs_len): + decoded, _ = self.decoder(encoded, inputs_len) + return decoded + + +def seq_mask(audio_lengths, max_len): + # Broadcast to vectorize creating the padding mask + padding_mask = torch.arange(max_len, device=audio_lengths.device) + padding_mask = padding_mask.expand(len(audio_lengths), max_len) < audio_lengths.unsqueeze(1) + return padding_mask diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/hmodels/modules/gaussian_upsampling.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/hmodels/modules/gaussian_upsampling.py new file mode 100644 index 0000000000000000000000000000000000000000..c3546b08ef22c175e61b54cbd58cab6f458aac2a --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/hmodels/modules/gaussian_upsampling.py @@ -0,0 +1,60 @@ +import torch +from torch import nn +from torch.nn import functional as F +import numpy as np + + +class GaussianUpsampling(nn.Module): + """ + Non-attention Tacotron: + - https://arxiv.org/abs/2010.04301 + this source code is implemenation of the ExpressiveTacotron from BridgetteSong + - https://github.com/BridgetteSong/ExpressiveTacotron/blob/master/model_duration.py + """ + def __init__(self, variance=1.0): + super().__init__() + self.mask_score = -1e15 + self.var = torch.tensor(variance) + self.HH_skip_init = True + + def forward(self, inputs, inputs_len, durations, output_max_len): + """ Gaussian upsampling + ------ + inputs: [B, N, H] + inputs_len : [B] + durations: phoneme durations [B, N] + vars : phoneme attended ranges [B, N] + RETURNS + ------- + upsampling_outputs: upsampled output [B, T, H] + """ + # output_len_max = int(torch.sum(durations, dim=1).max().item()) + w_t = get_upsampling_weights(durations, output_max_len, self.var, inputs_len) + + upsampling_outputs = torch.bmm(w_t.transpose(1, 2), inputs) # [B, T, encoder_hidden_size] + + return upsampling_outputs + + +def get_upsampling_weights(durations, output_max_len, variance, input_lens, mask_score=-1e15): + B, N = durations.shape + c = torch.cumsum(durations, dim=1).float() - 0.5 * durations + c = c.unsqueeze(2) # [B, N, 1] + t = torch.arange(output_max_len, device=durations.device).expand(B, N, output_max_len).float() # [B, N, T] + # Gaussian distribution density in log domain + w_t = -0.5 * (np.log(2.0 * np.pi) + torch.log(variance) + torch.pow(t - c, 2) / variance) # [B, N, T] + if input_lens is not None: + input_masks = ~get_mask_from_lengths(input_lens, N) # [B, N] + # input_masks = torch.tensor(input_masks, dtype=torch.bool, device=w_t.device) + masks = input_masks.unsqueeze(2) + w_t.data.masked_fill_(masks, mask_score) + w_t = F.softmax(w_t, dim=1) + return w_t + + +def get_mask_from_lengths(lengths, max_len=None): + if max_len is None: + max_len = max(lengths) + ids = torch.arange(max_len, device=lengths.device) + mask = (ids < lengths.reshape(-1, 1)) + return mask diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/hmodels/modules/whisper/decoding.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/hmodels/modules/whisper/decoding.py new file mode 100644 index 0000000000000000000000000000000000000000..c038c1581b33b67fdb7f637fdcc50efb807df22a --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/hmodels/modules/whisper/decoding.py @@ -0,0 +1,645 @@ +from dataclasses import dataclass, field +from typing import Dict, List, Tuple, Optional, Sequence, Union, TYPE_CHECKING, Any + +import numpy as np +import torch +import torch.nn.functional as F +from omegaconf import MISSING +from torch import Tensor +from torch.distributions import Categorical + + +if TYPE_CHECKING: + from .model import Whisper + + +@dataclass +class DecodingOptions: + task: str = "transcribe" # whether to perform X->X "transcribe" or X->English "translate" + language: Optional[str] = None # language that the audio is in; uses detected language if None + + # sampling-related options + temperature: float = 0.0 + sample_len: Optional[int] = MISSING # maximum number of tokens to sample + best_of: Optional[int] = None # number of independent samples to collect, when t > 0 + beam_size: Optional[int] = None # number of beams in beam search, when t == 0 + patience: Optional[float] = None # patience in beam search (https://arxiv.org/abs/2204.05424) + + # options for ranking generations (either beams or best-of-N samples) + length_penalty: Optional[float] = None # "alpha" in Google NMT, None defaults to length norm + + # prompt, prefix, and token suppression + prompt: Optional[Any] = None # text or tokens for the previous context + prefix: Optional[Any] = None # text or tokens to prefix the current context + suppress_blank: bool = False # this will suppress blank outputs + + # list of tokens ids (or comma-separated token ids) to suppress + # "-1" will suppress a set of symbols as defined in `tokenizer.non_speech_tokens()` + suppress_tokens: Optional[Any] = None + + # timestamp sampling options + without_timestamps: bool = False # use <|notimestamps|> to sample text tokens only + max_initial_timestamp: Optional[float] = 1.0 # the initial timestamp cannot be later than this + + # sot_id: Optional[int] = None + + # implementation details + fp16: bool = False # use fp16 for most of the calculation + + +@dataclass(frozen=True) +class DecodingResult: + audio_features: Tensor + language: str + language_probs: Optional[Dict[str, float]] = None + tokens: List[int] = field(default_factory=list) + text: str = "" + avg_logprob: float = np.nan + no_speech_prob: float = np.nan + temperature: float = np.nan + compression_ratio: float = np.nan + + +class Inference: + def logits(self, tokens: Tensor, audio_features: Tensor, audio_features_len: Tensor) -> Tensor: + """Perform a forward pass on the decoder and return per-token logits""" + raise NotImplementedError + + def rearrange_kv_cache(self, source_indices) -> None: + """Update the key-value cache according to the updated beams""" + raise NotImplementedError + + def cleanup_caching(self) -> None: + """Clean up any resources or hooks after decoding is finished""" + pass + + +class PyTorchInference(Inference): + def __init__(self, model: "Whisper", initial_token_length: int): + self.model: "Whisper" = model + self.initial_token_length = initial_token_length + self.kv_cache = {} + + def logits(self, tokens: Tensor, audio_features: Tensor, audio_features_len: Tensor) -> Tensor: + if tokens.shape[-1] > self.initial_token_length: + # only need to use the last token except in the first forward pass + tokens = tokens[:, -1:] + + return self.model.decoder(tokens, None, audio_features, audio_features_len, kv_cache=self.kv_cache) + + def cleanup_caching(self): + self.kv_cache = {} + + def rearrange_kv_cache(self, source_indices): + for module, (k_tensor, v_tensor) in self.kv_cache.items(): + # update the key/value cache to contain the selected sequences + self.kv_cache[module] = k_tensor[source_indices].detach(), v_tensor[source_indices].detach() + + +class SequenceRanker: + def rank(self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]]) -> List[int]: + """ + Given a list of groups of samples and their cumulative log probabilities, + return the indices of the samples in each group to select as the final result + """ + raise NotImplementedError + + +class MaximumLikelihoodRanker(SequenceRanker): + """ + Select the sample with the highest log probabilities, penalized using either + a simple length normalization or Google NMT paper's length penalty + """ + + def __init__(self, length_penalty: Optional[float]): + self.length_penalty = length_penalty + + def rank(self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]]): + def scores(logprobs, lengths): + result = [] + for logprob, length in zip(logprobs, lengths): + if self.length_penalty is None: + # +1 for EOS + penalty = length + 1 + else: + # from the Google NMT paper + penalty = ((5 + length) / 6) ** self.length_penalty + result.append(logprob / penalty) + return result + + # get the sequence with the highest score + lengths = [[len(t) for t in s] for s in tokens] + return [np.argmax(scores(p, l)) for p, l in zip(sum_logprobs, lengths)] + + +class TokenDecoder: + def reset(self): + """Initialize any stateful variables for decoding a new sequence""" + + def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]: + """Specify how to select the next token, based on the current trace and logits + + Parameters + ---------- + tokens : Tensor, shape = (n_batch, current_sequence_length) + all tokens in the context so far, including the prefix and sot_sequence tokens + + logits : Tensor, shape = (n_batch, vocab_size) + per-token logits of the probability distribution at the current step + + sum_logprobs : Tensor, shape = (n_batch) + cumulative log probabilities for each sequence + + Returns + ------- + tokens : Tensor, shape = (n_batch, current_sequence_length + 1) + the tokens, appended with the selected next token + + completed : bool + True if all sequences has reached the end of text + + """ + raise NotImplementedError + + def finalize( + self, tokens: Tensor, sum_logprobs: Tensor + ) -> Tuple[Sequence[Sequence[Tensor]], List[List[float]]]: + """Finalize search and return the final candidate sequences + + Parameters + ---------- + tokens : Tensor, shape = (n_audio, n_group, current_sequence_length) + all tokens in the context so far, including the prefix and sot_sequence + + sum_logprobs : Tensor, shape = (n_audio, n_group) + cumulative log probabilities for each sequence + + Returns + ------- + tokens : Sequence[Sequence[Tensor]], length = n_audio + sequence of Tensors containing candidate token sequences, for each audio input + + sum_logprobs : List[List[float]], length = n_audio + sequence of cumulative log probabilities corresponding to the above + + """ + raise NotImplementedError + + +class GreedyDecoder(TokenDecoder): + def __init__(self, temperature: float, eot: int): + self.temperature = temperature + self.eot = eot + + def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]: + temperature = self.temperature + if temperature == 0: + next_tokens = logits.argmax(dim=-1) + else: + next_tokens = Categorical(logits=logits / temperature).sample() + + logprobs = F.log_softmax(logits.float(), dim=-1) + current_logprobs = logprobs[torch.arange(logprobs.shape[0]), next_tokens] + sum_logprobs += current_logprobs * (tokens[:, -1] != self.eot) + + next_tokens[tokens[:, -1] == self.eot] = self.eot + tokens = torch.cat([tokens, next_tokens[:, None]], dim=-1) + + completed = (tokens[:, -1] == self.eot).all() + return tokens, completed + + def finalize(self, tokens: Tensor, sum_logprobs: Tensor): + # make sure each sequence has at least one EOT token at the end + tokens = F.pad(tokens, (0, 1), value=self.eot) + return tokens, sum_logprobs.tolist() + + +class BeamSearchDecoder(TokenDecoder): + def __init__(self, beam_size: int, eot: int, inference: Inference, patience: Optional[float] = None): + self.beam_size = beam_size + self.eot = eot + self.inference = inference + self.patience = patience or 1.0 + self.max_candidates: int = round(beam_size * self.patience) + self.finished_sequences = None + + assert self.max_candidates > 0, f"Invalid beam size ({beam_size}) or patience ({patience})" + + def reset(self): + self.finished_sequences = None + + def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]: + if tokens.shape[0] % self.beam_size != 0: + raise ValueError(f"{tokens.shape}[0] % {self.beam_size} != 0") + + n_audio = tokens.shape[0] // self.beam_size + if self.finished_sequences is None: # for the first update + self.finished_sequences = [{} for _ in range(n_audio)] + + logprobs = F.log_softmax(logits.float(), dim=-1) + next_tokens, source_indices, finished_sequences = [], [], [] + for i in range(n_audio): + scores, sources, finished = {}, {}, {} + + # STEP 1: calculate the cumulative log probabilities for possible candidates + for j in range(self.beam_size): + idx = i * self.beam_size + j + prefix = tokens[idx].tolist() + for logprob, token in zip(*logprobs[idx].topk(self.beam_size + 1)): + new_logprob = (sum_logprobs[idx] + logprob).item() + sequence = tuple(prefix + [token.item()]) + scores[sequence] = new_logprob + sources[sequence] = idx + + # STEP 2: rank the candidates and keep the top beam_size sequences for each audio + saved = 0 + for sequence in sorted(scores, key=scores.get, reverse=True): + if sequence[-1] == self.eot: + finished[sequence] = scores[sequence] + else: + sum_logprobs[len(next_tokens)] = scores[sequence] + next_tokens.append(sequence) + source_indices.append(sources[sequence]) + + saved += 1 + if saved == self.beam_size: + break + + finished_sequences.append(finished) + + tokens = torch.tensor(next_tokens, device=tokens.device) + self.inference.rearrange_kv_cache(source_indices) + + # add newly finished sequences to self.finished_sequences + assert len(self.finished_sequences) == len(finished_sequences) + for previously_finished, newly_finished in zip(self.finished_sequences, finished_sequences): + for seq in sorted(newly_finished, key=newly_finished.get, reverse=True): + if len(previously_finished) >= self.max_candidates: + break # the candidate list is full + previously_finished[seq] = newly_finished[seq] + + # mark as completed if all audio has enough number of samples + completed = all( + len(sequences) >= self.max_candidates for sequences in self.finished_sequences + ) + return tokens, completed + + def finalize(self, preceding_tokens: Tensor, sum_logprobs: Tensor): + # collect all finished sequences, including patience, and add unfinished ones if not enough + sum_logprobs = sum_logprobs.cpu() + for i, sequences in enumerate(self.finished_sequences): + if len(sequences) < self.beam_size: # when not enough sequences are finished + for j in list(np.argsort(sum_logprobs[i]))[::-1]: + sequence = preceding_tokens[i, j].tolist() + [self.eot] + sequences[tuple(sequence)] = sum_logprobs[i][j].item() + if len(sequences) >= self.beam_size: + break + + tokens: List[List[Tensor]] = [ + [torch.tensor(seq) for seq in sequences.keys()] for sequences in self.finished_sequences + ] + sum_logprobs: List[List[float]] = [ + list(sequences.values()) for sequences in self.finished_sequences + ] + return tokens, sum_logprobs + + +class LogitFilter: + def apply(self, logits: Tensor, tokens: Tensor) -> None: + """Apply any filtering or masking to logits in-place + + Parameters + ---------- + logits : Tensor, shape = (n_batch, vocab_size) + per-token logits of the probability distribution at the current step + + tokens : Tensor, shape = (n_batch, current_sequence_length) + all tokens in the context so far, including the prefix and sot_sequence tokens + + """ + raise NotImplementedError + + +class SuppressBlank(LogitFilter): + def __init__(self, tokenizer, sample_begin: int): + self.tokenizer = tokenizer + self.sample_begin = sample_begin + + def apply(self, logits: Tensor, tokens: Tensor): + if tokens.shape[1] == self.sample_begin: + logits[:, self.tokenizer.encode(" ") + [self.tokenizer.eot]] = -np.inf + + +class SuppressTokens(LogitFilter): + def __init__(self, suppress_tokens: Sequence[int]): + self.suppress_tokens = list(suppress_tokens) + + def apply(self, logits: Tensor, tokens: Tensor): + logits[:, self.suppress_tokens] = -np.inf + + +class ApplyTimestampRules(LogitFilter): + def __init__( + self, tokenizer, sample_begin: int, max_initial_timestamp_index: Optional[int] + ): + self.tokenizer = tokenizer + self.sample_begin = sample_begin + self.max_initial_timestamp_index = max_initial_timestamp_index + + def apply(self, logits: Tensor, tokens: Tensor): + # suppress <|notimestamps|> which is handled by without_timestamps + if self.tokenizer.no_timestamps is not None: + logits[:, self.tokenizer.no_timestamps] = -np.inf + + # timestamps have to appear in pairs, except directly before EOT; mask logits accordingly + for k in range(tokens.shape[0]): + seq = [t for t in tokens[k, self.sample_begin :].tolist()] + last_was_timestamp = len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin + penultimate_was_timestamp = len(seq) < 2 or seq[-2] >= self.tokenizer.timestamp_begin + + if last_was_timestamp: + if penultimate_was_timestamp: # has to be non-timestamp + logits[k, self.tokenizer.timestamp_begin :] = -np.inf + else: # cannot be normal text tokens + logits[k, : self.tokenizer.eot] = -np.inf + + # apply the `max_initial_timestamp` option + if tokens.shape[1] == self.sample_begin and self.max_initial_timestamp_index is not None: + last_allowed = self.tokenizer.timestamp_begin + self.max_initial_timestamp_index + logits[:, last_allowed + 1 :] = -np.inf + + # if sum of probability over timestamps is above any other token, sample timestamp + logprobs = F.log_softmax(logits.float(), dim=-1) + for k in range(tokens.shape[0]): + timestamp_logprob = logprobs[k, self.tokenizer.timestamp_begin :].logsumexp(dim=-1) + max_text_token_logprob = logprobs[k, : self.tokenizer.timestamp_begin].max() + if timestamp_logprob > max_text_token_logprob: + logits[k, : self.tokenizer.timestamp_begin] = -np.inf + + +class DecodingTask: + inference: Inference + sequence_ranker: SequenceRanker + decoder: TokenDecoder + logit_filters: List[LogitFilter] + + def __init__(self, model: "Whisper", tokenizer, options: DecodingOptions): + self.model = model + + # language = options.language or "en" + # tokenizer = get_tokenizer(model.is_multilingual, language=language, task=options.task) + self.tokenizer = tokenizer + sot_id = self.tokenizer.tokenizer.bos_id() + eot_id = self.tokenizer.tokenizer.eos_id() + assert sot_id >= 0 and eot_id >= 0 + self.options: DecodingOptions = self._verify_options(options) + + self.n_group: int = options.beam_size or options.best_of or 1 + # self.n_ctx: int = model.dims.n_text_ctx + # self.sample_len: int = options.sample_len or model.dims.n_text_ctx // 2 + self.sample_len: int = options.sample_len + + # self.sot_sequence: Tuple[int] = tokenizer.sot_sequence + # if self.options.without_timestamps: + # self.sot_sequence = tokenizer.sot_sequence_including_notimestamps + self.sot_sequence: Tuple[int] = (sot_id,) + + self.initial_tokens: Tuple[int] = self._get_initial_tokens() + self.sample_begin: int = len(self.initial_tokens) + self.sot_index: int = self.initial_tokens.index(sot_id) + assert self.sample_begin == 1 + self.eot_id = eot_id + + # inference: implements the forward pass through the decoder, including kv caching + self.inference = PyTorchInference(model, len(self.initial_tokens)) + + # sequence ranker: implements how to rank a group of sampled sequences + self.sequence_ranker = MaximumLikelihoodRanker(options.length_penalty) + + # decoder: implements how to select the next tokens, given the autoregressive distribution + if options.beam_size is not None: + self.decoder = BeamSearchDecoder( + options.beam_size, self.eot_id, self.inference, options.patience + ) + else: + self.decoder = GreedyDecoder(options.temperature, self.eot_id) + + # logit filters: applies various rules to suppress or penalize certain tokens + self.logit_filters = [] + if self.options.suppress_blank: + self.logit_filters.append(SuppressBlank(self.tokenizer, self.sample_begin)) + if self.options.suppress_tokens: + self.logit_filters.append(SuppressTokens(self._get_suppress_tokens())) + # if not options.without_timestamps: + # precision = CHUNK_LENGTH / model.dims.n_audio_ctx # usually 0.02 seconds + # max_initial_timestamp_index = None + # if options.max_initial_timestamp: + # max_initial_timestamp_index = round(self.options.max_initial_timestamp / precision) + # self.logit_filters.append( + # ApplyTimestampRules(tokenizer, self.sample_begin, max_initial_timestamp_index) + # ) + + def _verify_options(self, options: DecodingOptions) -> DecodingOptions: + if options.beam_size is not None and options.best_of is not None: + raise ValueError("beam_size and best_of can't be given together") + if options.temperature == 0: + if options.best_of is not None: + raise ValueError("best_of with greedy sampling (T=0) is not compatible") + if options.patience is not None and options.beam_size is None: + raise ValueError("patience requires beam_size to be given") + if options.length_penalty is not None and not (0 <= options.length_penalty <= 1): + raise ValueError("length_penalty (alpha) should be a value between 0 and 1") + + return options + + def _get_initial_tokens(self) -> Tuple[int]: + tokens = list(self.sot_sequence) + prefix = self.options.prefix + prompt = self.options.prompt + + if prefix: + prefix_tokens = ( + self.tokenizer.encode(" " + prefix.strip()) if isinstance(prefix, str) else prefix + ) + # if self.sample_len is not None: + # max_prefix_len = self.n_ctx // 2 - self.sample_len + # prefix_tokens = prefix_tokens[-max_prefix_len:] + tokens = tokens + prefix_tokens + + if prompt: + prompt_tokens = ( + self.tokenizer.encode(" " + prompt.strip()) if isinstance(prompt, str) else prompt + ) + # tokens = [self.tokenizer.sot_prev] + prompt_tokens[-(self.n_ctx // 2 - 1) :] + tokens + tokens = [self.tokenizer.sot_prev] + prompt_tokens + tokens + + return tuple(tokens) + + def _get_suppress_tokens(self) -> Tuple[int]: + suppress_tokens = self.options.suppress_tokens + + if isinstance(suppress_tokens, str): + suppress_tokens = [int(t) for t in suppress_tokens.split(",")] + + if -1 in suppress_tokens: + suppress_tokens = [t for t in suppress_tokens if t >= 0] + suppress_tokens.extend(self.tokenizer.non_speech_tokens) + elif suppress_tokens is None or len(suppress_tokens) == 0: + suppress_tokens = [] # interpret empty string as an empty list + else: + assert isinstance(suppress_tokens, list), "suppress_tokens must be a list" + + suppress_tokens.extend( + [self.tokenizer.sot, self.tokenizer.sot_prev, self.tokenizer.sot_lm] + ) + if self.tokenizer.no_speech is not None: + # no-speech probability is collected separately + suppress_tokens.append(self.tokenizer.no_speech) + + return tuple(sorted(set(suppress_tokens))) + + def _get_audio_features(self, mel: Tensor, mel_len: Tensor): + if self.options.fp16: + mel = mel.half() + + # if mel.shape[-2:] == (self.model.dims.n_audio_ctx, self.model.dims.n_audio_state): + # encoded audio features are given; skip audio encoding + # audio_features = mel + # else: + audio_features, audio_features_len = self.model.embed_audio(mel, mel_len) + + if audio_features.dtype != (torch.float16 if self.options.fp16 else torch.float32): + return TypeError(f"audio_features has an incorrect dtype: {audio_features.dtype}") + + return audio_features, audio_features_len + + def _detect_language(self, audio_features: Tensor, tokens: Tensor): + languages = [self.options.language] * audio_features.shape[0] + lang_probs = None + + if self.options.language is None or self.options.task == "lang_id": + lang_tokens, lang_probs = self.model.detect_language(audio_features, self.tokenizer) + languages = [max(probs, key=probs.get) for probs in lang_probs] + if self.options.language is None: + tokens[:, self.sot_index + 1] = lang_tokens # write language tokens + + return languages, lang_probs + + def _main_loop(self, audio_features: Tensor, audio_features_len: Tensor, tokens: Tensor): + assert audio_features.shape[0] == tokens.shape[0] + n_batch = tokens.shape[0] + sum_logprobs: Tensor = torch.zeros(n_batch, device=audio_features.device) + # no_speech_probs = [np.nan] * n_batch + + try: + for i in range(self.sample_len): + logits = self.inference.logits(tokens, audio_features, audio_features_len) + + # if i == 0 and self.tokenizer.no_speech is not None: # save no_speech_probs + # probs_at_sot = logits[:, self.sot_index].float().softmax(dim=-1) + # no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist() + + # now we need to consider the logits at the last token only + logits = logits[:, -1] + + # apply the logit filters, e.g. for suppressing or applying penalty to + for logit_filter in self.logit_filters: + logit_filter.apply(logits, tokens) + + # expand the tokens tensor with the selected next tokens + tokens, completed = self.decoder.update(tokens, logits, sum_logprobs) + + # if completed or tokens.shape[-1] > self.n_ctx: + if completed: + break + finally: + self.inference.cleanup_caching() + + return tokens, sum_logprobs + + @torch.no_grad() + def run(self, mel: Tensor, mel_len: Tensor): + self.decoder.reset() + tokenizer = self.tokenizer + n_audio: int = mel.shape[0] + + audio_features, audio_features_len = self._get_audio_features(mel, mel_len) # encoder forward pass + tokens: Tensor = torch.tensor([self.initial_tokens]).repeat(n_audio, 1) + + # detect language if requested, overwriting the language token + # languages, language_probs = self._detect_language(audio_features, tokens) + # if self.options.task == "lang_id": + # return [ + # DecodingResult(audio_features=features, language=language, language_probs=probs) + # for features, language, probs in zip(audio_features, languages, language_probs) + # ] + + # repeat the audio & text tensors by the group size, for beam search or best-of-n sampling + audio_features = audio_features.repeat_interleave(self.n_group, dim=0) + audio_features_len = audio_features_len.repeat_interleave(self.n_group, dim=0) + tokens = tokens.repeat_interleave(self.n_group, dim=0).to(audio_features.device) + + # call the main sampling loop + tokens, sum_logprobs, = self._main_loop(audio_features, audio_features_len, tokens) + + # reshape the tensors to have (n_audio, n_group) as the first two dimensions + audio_features = audio_features[:: self.n_group] + # no_speech_probs = no_speech_probs[:: self.n_group] + assert audio_features.shape[0] == n_audio + + tokens = tokens.reshape(n_audio, self.n_group, -1) + sum_logprobs = sum_logprobs.reshape(n_audio, self.n_group) + + # get the final candidates for each group, and slice between the first sampled token and EOT + tokens, sum_logprobs = self.decoder.finalize(tokens, sum_logprobs) + tokens: List[List[Tensor]] = [ + [t[self.sample_begin : (t == self.eot_id).nonzero()[0, 0]] for t in s] for s in tokens + ] + + # select the top-ranked sample in each group + selected = self.sequence_ranker.rank(tokens, sum_logprobs) + tokens: List[List[int]] = [t[i].tolist() for i, t in zip(selected, tokens)] + texts: List[str] = [tokenizer.ids_to_text(t).strip() for t in tokens] + + sum_logprobs: List[float] = [lp[i] for i, lp in zip(selected, sum_logprobs)] + avg_logprobs: List[float] = [lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs)] + + fields = (texts, tokens, avg_logprobs) + if len(set(map(len, fields))) != 1: + raise RuntimeError(f"inconsistent result lengths: {list(map(len, fields))}") + + return fields + # return [(tokens, avg_logprob, no_speech_prob) for tokens, avg_logprob, no_speech_prob in zip(*fields)] + + +@torch.no_grad() +def decode(model: "Whisper", mel: Tensor, options: DecodingOptions = DecodingOptions()) -> Union[DecodingResult, List[DecodingResult]]: + """ + Performs decoding of 30-second audio segment(s), provided as Mel spectrogram(s). + + Parameters + ---------- + model: Whisper + the Whisper model instance + + mel: torch.Tensor, shape = (80, 3000) or (*, 80, 3000) + A tensor containing the Mel spectrogram(s) + + options: DecodingOptions + A dataclass that contains all necessary options for decoding 30-second segments + + Returns + ------- + result: Union[DecodingResult, List[DecodingResult]] + The result(s) of decoding contained in `DecodingResult` dataclass instance(s) + """ + single = mel.ndim == 2 + if single: + mel = mel.unsqueeze(0) + + result = DecodingTask(model, options).run(mel) + + if single: + result = result[0] + + return result diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/hmodels/modules/whisper/model.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/hmodels/modules/whisper/model.py new file mode 100644 index 0000000000000000000000000000000000000000..6d98b3c137d84470241f9d9ea75bace5e59aff36 --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/hmodels/modules/whisper/model.py @@ -0,0 +1,230 @@ +import math +from typing import Dict +from typing import Iterable, Optional + +import numpy as np +import torch +import torch.nn.functional as F +from torch import Tensor +from torch import nn + +from htrain.util import instantiate_from_config +from nemo.collections.asr.parts.convolution_layers import create_pad_mask +from .decoding import decode as decode_function + + +class MultiHeadAttention(nn.Module): + def __init__(self, n_state: int, n_head: int, dropout: float, n_kv: int = 0): + super().__init__() + self.n_head = n_head + self.query = nn.Linear(n_state, n_state) + self.qkv_same_dim = n_kv == n_state + self.key = nn.Linear(n_kv if n_kv else n_state, n_state, bias=False) + self.value = nn.Linear(n_kv if n_kv else n_state, n_state) + self.dropout = nn.Dropout(dropout) + self.out = nn.Linear(n_state, n_state) + + self.reset_parameters() + + def reset_parameters(self): + if self.qkv_same_dim: + # Empirically observed the convergence to be much better with + # the scaled initialization + nn.init.xavier_uniform_(self.query.weight, gain=1 / math.sqrt(2)) + nn.init.xavier_uniform_(self.key.weight, gain=1 / math.sqrt(2)) + nn.init.xavier_uniform_(self.value.weight, gain=1 / math.sqrt(2)) + else: + nn.init.xavier_uniform_(self.query.weight) + nn.init.xavier_uniform_(self.key.weight) + nn.init.xavier_uniform_(self.value.weight) + + nn.init.xavier_uniform_(self.out.weight) + if self.out.bias is not None: + nn.init.constant_(self.out.bias, 0.0) + # if self.value.bias is not None: + # nn.init.xavier_normal_(self.value.bias) + + def forward( + self, + x: Tensor, + xa: Optional[Tensor] = None, + self_attn_mask: Optional[Tensor] = None, + xattn_padding_mask: Optional[Tensor] = None, + kv_cache: Optional[dict] = None, + ): + q = self.query(x) + + if kv_cache is None: + k = self.key(x if xa is None else xa) + v = self.value(x if xa is None else xa) + else: + # decoding + if xa is None: + # for self-attention, calculate keys and values and concat with previous keys and values + k = self.key(x) + v = self.value(x) + prev_kv = kv_cache.get(self) + if prev_kv: + prev_k, prev_v = prev_kv + k = torch.cat([prev_k, k], dim=1) + v = torch.cat([prev_v, v], dim=1) + kv_cache[self] = k.detach(), v.detach() + else: + # for cross-attention, calculate keys and values once and reuse in subsequent calls. + kv = kv_cache.get(self) + if kv: + k, v = kv + else: + k = self.key(xa) + v = self.value(xa) + kv_cache[self] = k.detach(), v.detach() + + wv = self.qkv_attention(q, k, v, self_attn_mask, xattn_padding_mask) + return self.out(wv) + + def qkv_attention(self, q: Tensor, k: Tensor, v: Tensor, self_attn_mask: Optional[Tensor] = None, xattn_padding_mask: Optional[Tensor] = None): + n_batch, n_ctx, n_state = q.shape + scale = (n_state // self.n_head) ** -0.25 + q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale + k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale + v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) + + # [B, n_head, q_len, k_len] + qk = q @ k + + if self_attn_mask is not None: + qk = qk + self_attn_mask + + if xattn_padding_mask is not None: + # don't attend to padding symbols + qk = qk.masked_fill(xattn_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), float("-inf")) + + w = F.softmax(qk.float(), dim=-1).to(q.dtype) + w = self.dropout(w) + return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2) + + +class ResidualAttentionBlock(nn.Module): + def __init__(self, n_state: int, n_head: int, n_mlp: int, dropout: float, dropout_attn: float, dropout_mlp: float, + *, act_fn='gelu', cross_attn: bool = False, n_xattn_kv: int = 0): + super().__init__() + + self.attn = MultiHeadAttention(n_state, n_head, dropout_attn) + self.attn_ln = nn.LayerNorm(n_state) + + self.cross_attn = MultiHeadAttention(n_state, n_head, dropout_attn, n_xattn_kv) if cross_attn else None + self.cross_attn_ln = nn.LayerNorm(n_state) if cross_attn else None + + self.dropout = nn.Dropout(dropout) + + if act_fn == 'gelu': + act = nn.GELU() + else: + assert act_fn == 'relu' + act = nn.ReLU() + self.mlp = nn.Sequential(nn.Linear(n_state, n_mlp), act, nn.Dropout(dropout_mlp), nn.Linear(n_mlp, n_state)) + self.mlp_ln = nn.LayerNorm(n_state) + + def forward( + self, + x: Tensor, + xa: Optional[Tensor] = None, + self_attn_mask: Optional[Tensor] = None, + xattn_padding_mask: Optional[Tensor] = None, + kv_cache: Optional[dict] = None, + ): + x = x + self.dropout( + self.attn(self.attn_ln(x), self_attn_mask=self_attn_mask, kv_cache=kv_cache) + ) + if self.cross_attn: + x = x + self.dropout(self.cross_attn(self.cross_attn_ln(x), xa, xattn_padding_mask=xattn_padding_mask, kv_cache=kv_cache)) + x = x + self.dropout(self.mlp(self.mlp_ln(x))) + return x + + +class TextDecoder(nn.Module): + def __init__(self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_mlp: int, n_layer: int, + n_encoder_state: int, dropout: float, dropout_attn: float, dropout_mlp: float, layer_drop: float, + act_fn: str = 'gelu'): + super().__init__() + + token_embedding = nn.Embedding(n_vocab, n_state) + nn.init.normal_(token_embedding.weight, mean=0, std=n_state ** -0.5) + self.token_embedding = token_embedding + self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state).normal_(mean=0, std=n_state ** -0.5)) + + self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList( + [ResidualAttentionBlock( + n_state, n_head, n_mlp, dropout=dropout, dropout_attn=dropout_attn, dropout_mlp=dropout_mlp, + cross_attn=True, n_xattn_kv=n_encoder_state, act_fn=act_fn + ) + for _ in range(n_layer)] + ) + self.layer_drop = layer_drop + self.ln = nn.LayerNorm(n_state) + + self._future_mask_buffer = torch.empty(0) + + def forward(self, x: Tensor, x_len: Tensor, xa: Tensor, xa_len: Tensor, kv_cache: Optional[dict] = None): + """ + x : torch.LongTensor, shape = (batch_size, <= n_ctx) + the text tokens + xa : torch.Tensor, shape = (batch_size, n_mels, n_audio_ctx) + the encoded audio features to be attended on + """ + if kv_cache: + offset = next(iter(kv_cache.values()))[0].shape[1] + else: + offset = 0 + x = self.token_embedding(x) + self.positional_embedding[offset: offset + x.shape[-1]] + x = x.to(xa.dtype) + + future_mask = self.get_future_mask(x) + encoder_padding_mask = create_pad_mask(xa_len, xa.shape[1]) + for block in self.blocks: + if self.training and self.layer_drop > 0 and np.random.random() < self.layer_drop: + continue + x = block(x, xa, self_attn_mask=future_mask, xattn_padding_mask=encoder_padding_mask, kv_cache=kv_cache) + + x = self.ln(x) + logits = (x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)).float() + + return logits + + def get_future_mask(self, tensor): + dim = tensor.size(1) + if self._future_mask_buffer.size(0) < dim: + self._future_mask_buffer = torch.empty(dim, dim).fill_(float('-inf')).triu_(1) + self._future_mask_buffer = self._future_mask_buffer.to(tensor) + return self._future_mask_buffer[:dim, :dim] + + +class Whisper(nn.Module): + def __init__(self, encoder, decoder): + super().__init__() + + self.encoder = instantiate_from_config(encoder) + self.decoder = TextDecoder(**decoder) + + def embed_audio(self, mel: torch.Tensor, mel_len: torch.Tensor): + feat, feat_len, _ = self.encoder.forward(mel, mel_len) + # [B, D, T] => [B, T, D] + feat = torch.transpose(feat, 1, 2) + return feat, feat_len + + def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor): + return self.decoder.forward(tokens, audio_features) + + def forward(self, mel: torch.Tensor, mel_len: torch.Tensor, tokens: torch.Tensor, tokens_len: torch.Tensor) -> Dict[str, torch.Tensor]: + feat, feat_len, _ = self.encoder(mel, mel_len) + # [B, D, T] => [B, T, D] + feat = torch.transpose(feat, 1, 2) + + output = self.decoder(tokens, tokens_len, feat, feat_len) + return output + + @property + def device(self): + return next(self.parameters()).device + + decode = decode_function diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/my_S2U_model/SPIRAL2_base_mutilingual_wenet_lv13k960_pretrain_aishell1_ls100_finetune_FSQ_8888_CTC_4ConvDec_phone_40ms_1n8g_20/hparams.yaml b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/my_S2U_model/SPIRAL2_base_mutilingual_wenet_lv13k960_pretrain_aishell1_ls100_finetune_FSQ_8888_CTC_4ConvDec_phone_40ms_1n8g_20/hparams.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d969cbf23f1a8457931f86c63a04731edf69fa08 --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/my_S2U_model/SPIRAL2_base_mutilingual_wenet_lv13k960_pretrain_aishell1_ls100_finetune_FSQ_8888_CTC_4ConvDec_phone_40ms_1n8g_20/hparams.yaml @@ -0,0 +1,1643 @@ +train_ds: + batch_size: 14 + drop_last: false + shuffle: true + num_workers: 4 + pin_memory: true + manifest_dir: /home/ma-user/work/taodehua/data/multi_lingual_AISHELL1_LibriSpeech + data_dir: /home/ma-user/work/taodehua/data/multi_lingual_AISHELL1_LibriSpeech + manifest_filepath: phone_sequence_aishell1_train.json,phone_sequence_librispeech_train-clean-100.json + sample_rate: 16000 + labels: + - en_AA0 + - en_AA1 + - en_AA2 + - en_AE0 + - en_AE1 + - en_AE2 + - en_AH0 + - en_AH1 + - en_AH2 + - en_AO0 + - en_AO1 + - en_AO2 + - en_AW0 + - en_AW1 + - en_AW2 + - en_AY0 + - en_AY1 + - en_AY2 + - en_B + - en_CH + - en_D + - en_DH + - en_EH0 + - en_EH1 + - en_EH2 + - en_ER0 + - en_ER1 + - en_ER2 + - en_EY0 + - en_EY1 + - en_EY2 + - en_F + - en_G + - en_HH + - en_IH0 + - en_IH1 + - en_IH2 + - en_IY0 + - en_IY1 + - en_IY2 + - en_JH + - en_K + - en_L + - en_M + - en_N + - en_NG + - en_OW0 + - en_OW1 + - en_OW2 + - en_OY0 + - en_OY1 + - en_OY2 + - en_P + - en_R + - en_S + - en_SH + - en_T + - en_TH + - en_UH0 + - en_UH1 + - en_UH2 + - en_UW0 + - en_UW1 + - en_UW2 + - en_V + - en_W + - en_Y + - en_Z + - en_ZH + - CN_a1 + - CN_a2 + - CN_a3 + - CN_a4 + - CN_a5 + - CN_ai1 + - CN_ai2 + - CN_ai3 + - CN_ai4 + - CN_an1 + - CN_an2 + - CN_an3 + - CN_an4 + - CN_ang1 + - CN_ang2 + - CN_ang3 + - CN_ang4 + - CN_ao1 + - CN_ao2 + - CN_ao3 + - CN_ao4 + - CN_b + - CN_c + - CN_ch + - CN_d + - CN_e1 + - CN_e2 + - CN_e3 + - CN_e4 + - CN_e5 + - CN_ei1 + - CN_ei2 + - CN_ei3 + - CN_ei4 + - CN_en1 + - CN_en2 + - CN_en3 + - CN_en4 + - CN_en5 + - CN_eng1 + - CN_eng2 + - CN_eng3 + - CN_eng4 + - CN_f + - CN_g + - CN_h + - CN_i1 + - CN_i2 + - CN_i3 + - CN_i4 + - CN_ia1 + - CN_ia2 + - CN_ia3 + - CN_ia4 + - CN_ian1 + - CN_ian2 + - CN_ian3 + - CN_ian4 + - CN_iang1 + - CN_iang2 + - CN_iang3 + - CN_iang4 + - CN_iao1 + - CN_iao2 + - CN_iao3 + - CN_iao4 + - CN_ie1 + - CN_ie2 + - CN_ie3 + - CN_ie4 + - CN_ii1 + - CN_ii2 + - CN_ii3 + - CN_ii4 + - CN_ii5 + - CN_iii1 + - CN_iii2 + - CN_iii3 + - CN_iii4 + - CN_in1 + - CN_in2 + - CN_in3 + - CN_in4 + - CN_ing1 + - CN_ing2 + - CN_ing3 + - CN_ing4 + - CN_iong1 + - CN_iong2 + - CN_iong3 + - CN_iong4 + - CN_iou1 + - CN_iou2 + - CN_iou3 + - CN_iou4 + - CN_j + - CN_k + - CN_l + - CN_m + - CN_n + - CN_o1 + - CN_o2 + - CN_o3 + - CN_o4 + - CN_ong1 + - CN_ong2 + - CN_ong3 + - CN_ong4 + - CN_ou1 + - CN_ou2 + - CN_ou3 + - CN_ou4 + - CN_p + - CN_q + - CN_r + - CN_rr + - CN_s + - CN_sh + - CN_t + - CN_u1 + - CN_u2 + - CN_u3 + - CN_u4 + - CN_ua1 + - CN_ua2 + - CN_ua3 + - CN_ua4 + - CN_uai1 + - CN_uai2 + - CN_uai3 + - CN_uai4 + - CN_uan1 + - CN_uan2 + - CN_uan3 + - CN_uan4 + - CN_uang1 + - CN_uang2 + - CN_uang3 + - CN_uang4 + - CN_ui1 + - CN_ui2 + - CN_ui3 + - CN_ui4 + - CN_un1 + - CN_un2 + - CN_un3 + - CN_un4 + - CN_uo1 + - CN_uo2 + - CN_uo3 + - CN_uo4 + - CN_uo5 + - CN_v1 + - CN_v2 + - CN_v3 + - CN_v4 + - CN_van1 + - CN_van2 + - CN_van3 + - CN_van4 + - CN_ve1 + - CN_ve2 + - CN_ve3 + - CN_ve4 + - CN_vn1 + - CN_vn2 + - CN_vn3 + - CN_vn4 + - CN_x + - CN_z + - CN_zh + trim_silence: false + int_values: false + augmentor: null + max_duration: 24.0 + min_duration: null + max_utts: 0 + dup_factor: 1 + blank_index: -1 + unk_index: -1 + normalize: false + trim: true + load_audio: true + parser: en + parser_add_end_space: false + add_misc: false + subword_sampling_nbest_size: null + subword_sampling_alpha: null +validation_ds: + batch_size: 14 + drop_last: false + shuffle: false + num_workers: 4 + pin_memory: true + manifest_dir: /home/ma-user/work/taodehua/data/multi_lingual_AISHELL1_LibriSpeech + data_dir: /home/ma-user/work/taodehua/data/multi_lingual_AISHELL1_LibriSpeech + manifest_filepath: phone_sequence_aishell1_dev.json,phone_sequence_librispeech_dev-other.json + sample_rate: 16000 + labels: + - en_AA0 + - en_AA1 + - en_AA2 + - en_AE0 + - en_AE1 + - en_AE2 + - en_AH0 + - en_AH1 + - en_AH2 + - en_AO0 + - en_AO1 + - en_AO2 + - en_AW0 + - en_AW1 + - en_AW2 + - en_AY0 + - en_AY1 + - en_AY2 + - en_B + - en_CH + - en_D + - en_DH + - en_EH0 + - en_EH1 + - en_EH2 + - en_ER0 + - en_ER1 + - en_ER2 + - en_EY0 + - en_EY1 + - en_EY2 + - en_F + - en_G + - en_HH + - en_IH0 + - en_IH1 + - en_IH2 + - en_IY0 + - en_IY1 + - en_IY2 + - en_JH + - en_K + - en_L + - en_M + - en_N + - en_NG + - en_OW0 + - en_OW1 + - en_OW2 + - en_OY0 + - en_OY1 + - en_OY2 + - en_P + - en_R + - en_S + - en_SH + - en_T + - en_TH + - en_UH0 + - en_UH1 + - en_UH2 + - en_UW0 + - en_UW1 + - en_UW2 + - en_V + - en_W + - en_Y + - en_Z + - en_ZH + - CN_a1 + - CN_a2 + - CN_a3 + - CN_a4 + - CN_a5 + - CN_ai1 + - CN_ai2 + - CN_ai3 + - CN_ai4 + - CN_an1 + - CN_an2 + - CN_an3 + - CN_an4 + - CN_ang1 + - CN_ang2 + - CN_ang3 + - CN_ang4 + - CN_ao1 + - CN_ao2 + - CN_ao3 + - CN_ao4 + - CN_b + - CN_c + - CN_ch + - CN_d + - CN_e1 + - CN_e2 + - CN_e3 + - CN_e4 + - CN_e5 + - CN_ei1 + - CN_ei2 + - CN_ei3 + - CN_ei4 + - CN_en1 + - CN_en2 + - CN_en3 + - CN_en4 + - CN_en5 + - CN_eng1 + - CN_eng2 + - CN_eng3 + - CN_eng4 + - CN_f + - CN_g + - CN_h + - CN_i1 + - CN_i2 + - CN_i3 + - CN_i4 + - CN_ia1 + - CN_ia2 + - CN_ia3 + - CN_ia4 + - CN_ian1 + - CN_ian2 + - CN_ian3 + - CN_ian4 + - CN_iang1 + - CN_iang2 + - CN_iang3 + - CN_iang4 + - CN_iao1 + - CN_iao2 + - CN_iao3 + - CN_iao4 + - CN_ie1 + - CN_ie2 + - CN_ie3 + - CN_ie4 + - CN_ii1 + - CN_ii2 + - CN_ii3 + - CN_ii4 + - CN_ii5 + - CN_iii1 + - CN_iii2 + - CN_iii3 + - CN_iii4 + - CN_in1 + - CN_in2 + - CN_in3 + - CN_in4 + - CN_ing1 + - CN_ing2 + - CN_ing3 + - CN_ing4 + - CN_iong1 + - CN_iong2 + - CN_iong3 + - CN_iong4 + - CN_iou1 + - CN_iou2 + - CN_iou3 + - CN_iou4 + - CN_j + - CN_k + - CN_l + - CN_m + - CN_n + - CN_o1 + - CN_o2 + - CN_o3 + - CN_o4 + - CN_ong1 + - CN_ong2 + - CN_ong3 + - CN_ong4 + - CN_ou1 + - CN_ou2 + - CN_ou3 + - CN_ou4 + - CN_p + - CN_q + - CN_r + - CN_rr + - CN_s + - CN_sh + - CN_t + - CN_u1 + - CN_u2 + - CN_u3 + - CN_u4 + - CN_ua1 + - CN_ua2 + - CN_ua3 + - CN_ua4 + - CN_uai1 + - CN_uai2 + - CN_uai3 + - CN_uai4 + - CN_uan1 + - CN_uan2 + - CN_uan3 + - CN_uan4 + - CN_uang1 + - CN_uang2 + - CN_uang3 + - CN_uang4 + - CN_ui1 + - CN_ui2 + - CN_ui3 + - CN_ui4 + - CN_un1 + - CN_un2 + - CN_un3 + - CN_un4 + - CN_uo1 + - CN_uo2 + - CN_uo3 + - CN_uo4 + - CN_uo5 + - CN_v1 + - CN_v2 + - CN_v3 + - CN_v4 + - CN_van1 + - CN_van2 + - CN_van3 + - CN_van4 + - CN_ve1 + - CN_ve2 + - CN_ve3 + - CN_ve4 + - CN_vn1 + - CN_vn2 + - CN_vn3 + - CN_vn4 + - CN_x + - CN_z + - CN_zh + trim_silence: false + int_values: false + augmentor: null + max_duration: null + min_duration: null + max_utts: 0 + dup_factor: 1 + blank_index: -1 + unk_index: -1 + normalize: false + trim: true + load_audio: true + parser: en + parser_add_end_space: false + add_misc: false + subword_sampling_nbest_size: null + subword_sampling_alpha: null +test_ds: + batch_size: 14 + drop_last: false + shuffle: false + num_workers: 4 + pin_memory: true + manifest_dir: /home/ma-user/work/taodehua/data/multi_lingual_AISHELL1_LibriSpeech + data_dir: /home/ma-user/work/taodehua/data/multi_lingual_AISHELL1_LibriSpeech + manifest_filepath: phone_sequence_aishell1_test.json,phone_sequence_librispeech_test-clean.json + sample_rate: 16000 + labels: + - en_AA0 + - en_AA1 + - en_AA2 + - en_AE0 + - en_AE1 + - en_AE2 + - en_AH0 + - en_AH1 + - en_AH2 + - en_AO0 + - en_AO1 + - en_AO2 + - en_AW0 + - en_AW1 + - en_AW2 + - en_AY0 + - en_AY1 + - en_AY2 + - en_B + - en_CH + - en_D + - en_DH + - en_EH0 + - en_EH1 + - en_EH2 + - en_ER0 + - en_ER1 + - en_ER2 + - en_EY0 + - en_EY1 + - en_EY2 + - en_F + - en_G + - en_HH + - en_IH0 + - en_IH1 + - en_IH2 + - en_IY0 + - en_IY1 + - en_IY2 + - en_JH + - en_K + - en_L + - en_M + - en_N + - en_NG + - en_OW0 + - en_OW1 + - en_OW2 + - en_OY0 + - en_OY1 + - en_OY2 + - en_P + - en_R + - en_S + - en_SH + - en_T + - en_TH + - en_UH0 + - en_UH1 + - en_UH2 + - en_UW0 + - en_UW1 + - en_UW2 + - en_V + - en_W + - en_Y + - en_Z + - en_ZH + - CN_a1 + - CN_a2 + - CN_a3 + - CN_a4 + - CN_a5 + - CN_ai1 + - CN_ai2 + - CN_ai3 + - CN_ai4 + - CN_an1 + - CN_an2 + - CN_an3 + - CN_an4 + - CN_ang1 + - CN_ang2 + - CN_ang3 + - CN_ang4 + - CN_ao1 + - CN_ao2 + - CN_ao3 + - CN_ao4 + - CN_b + - CN_c + - CN_ch + - CN_d + - CN_e1 + - CN_e2 + - CN_e3 + - CN_e4 + - CN_e5 + - CN_ei1 + - CN_ei2 + - CN_ei3 + - CN_ei4 + - CN_en1 + - CN_en2 + - CN_en3 + - CN_en4 + - CN_en5 + - CN_eng1 + - CN_eng2 + - CN_eng3 + - CN_eng4 + - CN_f + - CN_g + - CN_h + - CN_i1 + - CN_i2 + - CN_i3 + - CN_i4 + - CN_ia1 + - CN_ia2 + - CN_ia3 + - CN_ia4 + - CN_ian1 + - CN_ian2 + - CN_ian3 + - CN_ian4 + - CN_iang1 + - CN_iang2 + - CN_iang3 + - CN_iang4 + - CN_iao1 + - CN_iao2 + - CN_iao3 + - CN_iao4 + - CN_ie1 + - CN_ie2 + - CN_ie3 + - CN_ie4 + - CN_ii1 + - CN_ii2 + - CN_ii3 + - CN_ii4 + - CN_ii5 + - CN_iii1 + - CN_iii2 + - CN_iii3 + - CN_iii4 + - CN_in1 + - CN_in2 + - CN_in3 + - CN_in4 + - CN_ing1 + - CN_ing2 + - CN_ing3 + - CN_ing4 + - CN_iong1 + - CN_iong2 + - CN_iong3 + - CN_iong4 + - CN_iou1 + - CN_iou2 + - CN_iou3 + - CN_iou4 + - CN_j + - CN_k + - CN_l + - CN_m + - CN_n + - CN_o1 + - CN_o2 + - CN_o3 + - CN_o4 + - CN_ong1 + - CN_ong2 + - CN_ong3 + - CN_ong4 + - CN_ou1 + - CN_ou2 + - CN_ou3 + - CN_ou4 + - CN_p + - CN_q + - CN_r + - CN_rr + - CN_s + - CN_sh + - CN_t + - CN_u1 + - CN_u2 + - CN_u3 + - CN_u4 + - CN_ua1 + - CN_ua2 + - CN_ua3 + - CN_ua4 + - CN_uai1 + - CN_uai2 + - CN_uai3 + - CN_uai4 + - CN_uan1 + - CN_uan2 + - CN_uan3 + - CN_uan4 + - CN_uang1 + - CN_uang2 + - CN_uang3 + - CN_uang4 + - CN_ui1 + - CN_ui2 + - CN_ui3 + - CN_ui4 + - CN_un1 + - CN_un2 + - CN_un3 + - CN_un4 + - CN_uo1 + - CN_uo2 + - CN_uo3 + - CN_uo4 + - CN_uo5 + - CN_v1 + - CN_v2 + - CN_v3 + - CN_v4 + - CN_van1 + - CN_van2 + - CN_van3 + - CN_van4 + - CN_ve1 + - CN_ve2 + - CN_ve3 + - CN_ve4 + - CN_vn1 + - CN_vn2 + - CN_vn3 + - CN_vn4 + - CN_x + - CN_z + - CN_zh + trim_silence: false + int_values: false + augmentor: null + max_duration: null + min_duration: null + max_utts: 0 + dup_factor: 1 + blank_index: -1 + unk_index: -1 + normalize: false + trim: true + load_audio: true + parser: en + parser_add_end_space: false + add_misc: false + subword_sampling_nbest_size: null + subword_sampling_alpha: null +optim: + name: adamw + lr: 3.0e-05 + sched: + hold_steps: null + hold_ratio: 0.4 + warmup_steps: null + warmup_ratio: 0.1 + warmup_power: null + name: PolynomialHoldDecayAnnealing + min_lr: 1.5e-06 + last_epoch: -1 + max_steps: 80000 + power: 1.0 + cycle: false + betas: + - 0.9 + - 0.98 + eps: 1.0e-06 + weight_decay: 0.01 + amsgrad: false +pretrain_chkpt_path: /home/ma-user/work/taodehua/pretrained_models/SPIRAL_BASE_pretrain_multi_lingual_wenetspeech_13k960_40ms_init80_from_13k960_v2/st2vec-1.6M.ckpt +encoder_type: st +encoder: + preprocessor: + _target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor + sample_rate: 16000 + window_size: 0.02 + window_stride: 0.01 + n_window_size: null + n_window_stride: null + window: hann + normalize: per_feature + n_fft: null + preemph: 0.97 + features: 128 + lowfreq: 0 + highfreq: null + log: true + log_zero_guard_type: add + log_zero_guard_value: 5.960464477539063e-08 + dither: 1.0e-05 + dither_train_only: true + pad_to: 16 + frame_splicing: 1 + stft_exact_pad: false + stft_conv: false + pad_value: 0 + mag_power: 2.0 + normalize_time_domain: true + feature_encoder: + _target_: nemo.collections.asr.parts.spec2vec.FeatureEncoder + feat_in: 128 + use_conv_mask: true + conv2d_block: null + conv_transformer_blocks: + - conv_layers: + - filters: 384 + kernel_size: + - 5 + stride: + - 2 + norm_type: ln + gn_groups: null + act_func: relu + dilation: + - 1 + dropout: 0.1 + padding: same + bias: true + residual: false + - filters: 512 + kernel_size: + - 5 + stride: + - 2 + norm_type: ln + gn_groups: null + act_func: relu + dilation: + - 1 + dropout: 0.1 + padding: same + bias: true + residual: false + - filters: 512 + kernel_size: + - 1 + stride: + - 1 + norm_type: ln + gn_groups: null + act_func: null + dilation: + - 1 + dropout: 0.0 + padding: same + bias: true + residual: false + transformer_block: + use_pytorch_transformer: false + dropout: 0.1 + conv: + conv_pos: 128 + conv_pos_groups: 16 + layer_drop: 0.0 + encoder: + encoder_layers: 2 + encoder_layerdrop: 0.0 + embedding_dim: 512 + ffn_embedding_dim: 2048 + num_attention_heads: 8 + dropout: 0.1 + attention_dropout: 0.1 + activation_dropout: 0.1 + activation_fn: gelu + layer_norm_first: true + - conv_layers: + - filters: 1536 + kernel_size: + - 5 + stride: + - 1 + norm_type: ln + gn_groups: null + act_func: relu + dilation: + - 1 + dropout: 0.1 + padding: same + bias: true + residual: false + - filters: 768 + kernel_size: + - 1 + stride: + - 1 + norm_type: ln + gn_groups: null + act_func: null + dilation: + - 1 + dropout: 0.0 + padding: same + bias: true + residual: false + transformer_block: + use_pytorch_transformer: false + dropout: 0.1 + conv: + conv_pos: 128 + conv_pos_groups: 16 + layer_drop: 0.0 + encoder: + encoder_layers: 10 + encoder_layerdrop: 0.1 + embedding_dim: 768 + ffn_embedding_dim: 3072 + num_attention_heads: 12 + dropout: 0.1 + attention_dropout: 0.1 + activation_dropout: 0.1 + activation_fn: gelu + layer_norm_first: true + use_tf_pad: true + ln_eps: 1.0e-05 + pretrained_encoder_path: null + freeze_feature_encoder: false + freeze_student: false + noise_mix_ratio: null + masking: + mask_prob: 0.0 + mask_type: static + mask_emb_type: gaussian + mask_other: 0 + mask_length: 20 + no_mask_overlap: false + mask_min_space: 1 + mask_channel_prob: 0.0 + mask_channel_type: static + mask_channel_other: 0 + mask_channel_length: 20 + no_mask_channel_overlap: false + mask_channel_min_space: 1 + mask_shrink_to_batch_min: false + mask_channel_shrink_to_batch_min: false + shifting: null + target_shifting: + dist: uniform + shift_prob: 1.0 + max_ratio: 0.5 + unit: 4 + max: 16 + min: 0 + mean: null + std: null + truncate: false + target_masking: null + target_compute_perturb: true + target_momentum: 0.99 + target_momentum_final: 0.999 + target_momentum_steps: 2000000 + target_momentum_type: cosine + projector: + input_dim: null + output_dim: 256 + use_conv_mask: true + use_tf_pad: true + ln_eps: 1.0e-05 + conv_layers: null + transformer: null + predictor: + input_dim: null + output_dim: 256 + use_conv_mask: true + use_tf_pad: true + ln_eps: 1.0e-05 + conv_layers: + - filters: 256 + kernel_size: + - 5 + stride: + - 1 + norm_type: bn + gn_groups: null + act_func: relu + dilation: + - 1 + dropout: 0.0 + padding: same + bias: null + residual: false + - filters: 256 + kernel_size: + - 5 + stride: + - 1 + norm_type: bn + gn_groups: null + act_func: relu + dilation: + - 1 + dropout: 0.0 + padding: same + bias: null + residual: false + transformer: null + quantizer: null + n_negatives: 100 + cross_sample_negatives: 0 + codebook_negatives: 0 + negatives_from_everywhere: false + negatives_from_noisy_features: false + pitch_estimation_task: null + pitch_loss_weight: 0.0 + reconstruction_task: null + reconstruction_loss_weight: 0.0 + reconstruction_quant_ppl_loss_weight: 0.0 +decoder: + _target_: nemo.collections.asr.modules.ConvASRDecoder + feat_in: 4 + num_classes: 0 + proj_upsampling: null + conv_layers: + - filters: 512 + kernel_size: + - 5 + stride: + - 1 + norm_type: null + gn_groups: null + act_func: relu + dilation: + - 1 + dropout: 0.1 + padding: same + bias: null + residual: false + - filters: 512 + kernel_size: + - 5 + stride: + - 1 + norm_type: null + gn_groups: null + act_func: relu + dilation: + - 1 + dropout: 0.1 + padding: same + bias: null + residual: false + - filters: 512 + kernel_size: + - 5 + stride: + - 1 + norm_type: null + gn_groups: null + act_func: relu + dilation: + - 1 + dropout: 0.1 + padding: same + bias: null + residual: false + - filters: 512 + kernel_size: + - 5 + stride: + - 1 + norm_type: null + gn_groups: null + act_func: relu + dilation: + - 1 + dropout: 0.1 + padding: same + bias: null + residual: false + projector: null + use_conv_mask: true + use_tf_pad: true + ln_eps: 1.0e-05 + blank_pos: after_vocab_last + init_mode: xavier_uniform + vocabulary: + - en_AA0 + - en_AA1 + - en_AA2 + - en_AE0 + - en_AE1 + - en_AE2 + - en_AH0 + - en_AH1 + - en_AH2 + - en_AO0 + - en_AO1 + - en_AO2 + - en_AW0 + - en_AW1 + - en_AW2 + - en_AY0 + - en_AY1 + - en_AY2 + - en_B + - en_CH + - en_D + - en_DH + - en_EH0 + - en_EH1 + - en_EH2 + - en_ER0 + - en_ER1 + - en_ER2 + - en_EY0 + - en_EY1 + - en_EY2 + - en_F + - en_G + - en_HH + - en_IH0 + - en_IH1 + - en_IH2 + - en_IY0 + - en_IY1 + - en_IY2 + - en_JH + - en_K + - en_L + - en_M + - en_N + - en_NG + - en_OW0 + - en_OW1 + - en_OW2 + - en_OY0 + - en_OY1 + - en_OY2 + - en_P + - en_R + - en_S + - en_SH + - en_T + - en_TH + - en_UH0 + - en_UH1 + - en_UH2 + - en_UW0 + - en_UW1 + - en_UW2 + - en_V + - en_W + - en_Y + - en_Z + - en_ZH + - CN_a1 + - CN_a2 + - CN_a3 + - CN_a4 + - CN_a5 + - CN_ai1 + - CN_ai2 + - CN_ai3 + - CN_ai4 + - CN_an1 + - CN_an2 + - CN_an3 + - CN_an4 + - CN_ang1 + - CN_ang2 + - CN_ang3 + - CN_ang4 + - CN_ao1 + - CN_ao2 + - CN_ao3 + - CN_ao4 + - CN_b + - CN_c + - CN_ch + - CN_d + - CN_e1 + - CN_e2 + - CN_e3 + - CN_e4 + - CN_e5 + - CN_ei1 + - CN_ei2 + - CN_ei3 + - CN_ei4 + - CN_en1 + - CN_en2 + - CN_en3 + - CN_en4 + - CN_en5 + - CN_eng1 + - CN_eng2 + - CN_eng3 + - CN_eng4 + - CN_f + - CN_g + - CN_h + - CN_i1 + - CN_i2 + - CN_i3 + - CN_i4 + - CN_ia1 + - CN_ia2 + - CN_ia3 + - CN_ia4 + - CN_ian1 + - CN_ian2 + - CN_ian3 + - CN_ian4 + - CN_iang1 + - CN_iang2 + - CN_iang3 + - CN_iang4 + - CN_iao1 + - CN_iao2 + - CN_iao3 + - CN_iao4 + - CN_ie1 + - CN_ie2 + - CN_ie3 + - CN_ie4 + - CN_ii1 + - CN_ii2 + - CN_ii3 + - CN_ii4 + - CN_ii5 + - CN_iii1 + - CN_iii2 + - CN_iii3 + - CN_iii4 + - CN_in1 + - CN_in2 + - CN_in3 + - CN_in4 + - CN_ing1 + - CN_ing2 + - CN_ing3 + - CN_ing4 + - CN_iong1 + - CN_iong2 + - CN_iong3 + - CN_iong4 + - CN_iou1 + - CN_iou2 + - CN_iou3 + - CN_iou4 + - CN_j + - CN_k + - CN_l + - CN_m + - CN_n + - CN_o1 + - CN_o2 + - CN_o3 + - CN_o4 + - CN_ong1 + - CN_ong2 + - CN_ong3 + - CN_ong4 + - CN_ou1 + - CN_ou2 + - CN_ou3 + - CN_ou4 + - CN_p + - CN_q + - CN_r + - CN_rr + - CN_s + - CN_sh + - CN_t + - CN_u1 + - CN_u2 + - CN_u3 + - CN_u4 + - CN_ua1 + - CN_ua2 + - CN_ua3 + - CN_ua4 + - CN_uai1 + - CN_uai2 + - CN_uai3 + - CN_uai4 + - CN_uan1 + - CN_uan2 + - CN_uan3 + - CN_uan4 + - CN_uang1 + - CN_uang2 + - CN_uang3 + - CN_uang4 + - CN_ui1 + - CN_ui2 + - CN_ui3 + - CN_ui4 + - CN_un1 + - CN_un2 + - CN_un3 + - CN_un4 + - CN_uo1 + - CN_uo2 + - CN_uo3 + - CN_uo4 + - CN_uo5 + - CN_v1 + - CN_v2 + - CN_v3 + - CN_v4 + - CN_van1 + - CN_van2 + - CN_van3 + - CN_van4 + - CN_ve1 + - CN_ve2 + - CN_ve3 + - CN_ve4 + - CN_vn1 + - CN_vn2 + - CN_vn3 + - CN_vn4 + - CN_x + - CN_z + - CN_zh +labels: +- en_AA0 +- en_AA1 +- en_AA2 +- en_AE0 +- en_AE1 +- en_AE2 +- en_AH0 +- en_AH1 +- en_AH2 +- en_AO0 +- en_AO1 +- en_AO2 +- en_AW0 +- en_AW1 +- en_AW2 +- en_AY0 +- en_AY1 +- en_AY2 +- en_B +- en_CH +- en_D +- en_DH +- en_EH0 +- en_EH1 +- en_EH2 +- en_ER0 +- en_ER1 +- en_ER2 +- en_EY0 +- en_EY1 +- en_EY2 +- en_F +- en_G +- en_HH +- en_IH0 +- en_IH1 +- en_IH2 +- en_IY0 +- en_IY1 +- en_IY2 +- en_JH +- en_K +- en_L +- en_M +- en_N +- en_NG +- en_OW0 +- en_OW1 +- en_OW2 +- en_OY0 +- en_OY1 +- en_OY2 +- en_P +- en_R +- en_S +- en_SH +- en_T +- en_TH +- en_UH0 +- en_UH1 +- en_UH2 +- en_UW0 +- en_UW1 +- en_UW2 +- en_V +- en_W +- en_Y +- en_Z +- en_ZH +- CN_a1 +- CN_a2 +- CN_a3 +- CN_a4 +- CN_a5 +- CN_ai1 +- CN_ai2 +- CN_ai3 +- CN_ai4 +- CN_an1 +- CN_an2 +- CN_an3 +- CN_an4 +- CN_ang1 +- CN_ang2 +- CN_ang3 +- CN_ang4 +- CN_ao1 +- CN_ao2 +- CN_ao3 +- CN_ao4 +- CN_b +- CN_c +- CN_ch +- CN_d +- CN_e1 +- CN_e2 +- CN_e3 +- CN_e4 +- CN_e5 +- CN_ei1 +- CN_ei2 +- CN_ei3 +- CN_ei4 +- CN_en1 +- CN_en2 +- CN_en3 +- CN_en4 +- CN_en5 +- CN_eng1 +- CN_eng2 +- CN_eng3 +- CN_eng4 +- CN_f +- CN_g +- CN_h +- CN_i1 +- CN_i2 +- CN_i3 +- CN_i4 +- CN_ia1 +- CN_ia2 +- CN_ia3 +- CN_ia4 +- CN_ian1 +- CN_ian2 +- CN_ian3 +- CN_ian4 +- CN_iang1 +- CN_iang2 +- CN_iang3 +- CN_iang4 +- CN_iao1 +- CN_iao2 +- CN_iao3 +- CN_iao4 +- CN_ie1 +- CN_ie2 +- CN_ie3 +- CN_ie4 +- CN_ii1 +- CN_ii2 +- CN_ii3 +- CN_ii4 +- CN_ii5 +- CN_iii1 +- CN_iii2 +- CN_iii3 +- CN_iii4 +- CN_in1 +- CN_in2 +- CN_in3 +- CN_in4 +- CN_ing1 +- CN_ing2 +- CN_ing3 +- CN_ing4 +- CN_iong1 +- CN_iong2 +- CN_iong3 +- CN_iong4 +- CN_iou1 +- CN_iou2 +- CN_iou3 +- CN_iou4 +- CN_j +- CN_k +- CN_l +- CN_m +- CN_n +- CN_o1 +- CN_o2 +- CN_o3 +- CN_o4 +- CN_ong1 +- CN_ong2 +- CN_ong3 +- CN_ong4 +- CN_ou1 +- CN_ou2 +- CN_ou3 +- CN_ou4 +- CN_p +- CN_q +- CN_r +- CN_rr +- CN_s +- CN_sh +- CN_t +- CN_u1 +- CN_u2 +- CN_u3 +- CN_u4 +- CN_ua1 +- CN_ua2 +- CN_ua3 +- CN_ua4 +- CN_uai1 +- CN_uai2 +- CN_uai3 +- CN_uai4 +- CN_uan1 +- CN_uan2 +- CN_uan3 +- CN_uan4 +- CN_uang1 +- CN_uang2 +- CN_uang3 +- CN_uang4 +- CN_ui1 +- CN_ui2 +- CN_ui3 +- CN_ui4 +- CN_un1 +- CN_un2 +- CN_un3 +- CN_un4 +- CN_uo1 +- CN_uo2 +- CN_uo3 +- CN_uo4 +- CN_uo5 +- CN_v1 +- CN_v2 +- CN_v3 +- CN_v4 +- CN_van1 +- CN_van2 +- CN_van3 +- CN_van4 +- CN_ve1 +- CN_ve2 +- CN_ve3 +- CN_ve4 +- CN_vn1 +- CN_vn2 +- CN_vn3 +- CN_vn4 +- CN_x +- CN_z +- CN_zh +tokenizer: null +add_end_space: false +lang: en +freeze_finetune_updates: 0 +noise_perturb: null +expected_gpu_num: 8 +label_type: phone +quantizer: + quantize_targets: true + quantize_input: false + same_quantizer: false + targets_bottleneck_dim: null + targets_bottleneck_act_fn: null + targets_bottleneck_dropout: 0.0 + latent_vars: 320 + latent_groups: 2 + latent_dim: 0 + latent_temp: + - 2 + - 0.5 + - 0.999995 + levels: + - 8 + - 8 + - 8 + - 8 + l2_norm: false + batch_norm: false +quant_ppl_loss_weight: 1.0 +use_teacher_encoder: false +target: nemo.collections.asr.models.spec2vec.vq_ctc_finetune.VQCTCFinetuneModel diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/my_conf/st2vec_lfr_pretrain_maskp5cp4gaus_tp3_tgtshift16_preln_lr3e3_40ms_fp16_init80ms_multilingual_2.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/my_conf/st2vec_lfr_pretrain_maskp5cp4gaus_tp3_tgtshift16_preln_lr3e3_40ms_fp16_init80ms_multilingual_2.py new file mode 100644 index 0000000000000000000000000000000000000000..cfabc974f55e0980582fcb8a201dbb730a5bae45 --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/my_conf/st2vec_lfr_pretrain_maskp5cp4gaus_tp3_tgtshift16_preln_lr3e3_40ms_fp16_init80ms_multilingual_2.py @@ -0,0 +1,251 @@ +# Copyright (C) 2021. Huawei Technologies Co., Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from nemo.collections.asr.models.configs.common_config import AudioDatasetConfig, AdamWParams, \ + CosineAnnealingParams, Conv2dBlock, Conv2dNormAct, Conv1dNormAct +from nemo.collections.asr.models.spec2vec.spec2vec_config import FeatureEncoderConfig, \ + ConvTransformerBlock, ProjectorConfig +from nemo.collections.asr.models.st2vec.st2vec_config import ST2VecEncoderConfig, ST2VecPretrainModelConfig, \ + ShiftPerturbConfig +from nemo.collections.asr.models.wav2vec.wav2vec_config import Wav2VecTransformerEncoderConfig, \ + Wav2VecTransformerConfig, ConvConfig, Wav2VecActivationType, QuantizerConfig, Wav2VecMaskingConfig, Wav2VecMaskType, \ + LossConfig +from nemo.collections.asr.modules.audio_preprocessing import AudioToMelSpectrogramPreprocessorConfig +from nemo.core.config import TrainerConfig +from nemo.core.config.modelPT import ModelPTConfig +from nemo.utils.exp_manager import ExpManagerConfig, CallbackParams + +config_name = 'st2vec' + +sample_rate = 16000 +num_features = 128 +#num_features = 129 +#max_steps=200000 +max_steps=2000000 +#lr=0.003 +lr=0.0003 + +st2vec_encoder = ST2VecEncoderConfig( + preprocessor=AudioToMelSpectrogramPreprocessorConfig( + normalize='per_feature', + sample_rate=sample_rate, + window_size=0.02, + window_stride=0.01, + window='hann', + features=num_features, + stft_conv=False, + dither_train_only=True, + normalize_time_domain=True, + # pitch=False, + ), + feature_encoder=FeatureEncoderConfig( + feat_in=num_features, + use_conv_mask=True, + use_tf_pad=True, + conv2d_block=None, + conv_transformer_blocks=[ConvTransformerBlock( + conv_layers=[Conv1dNormAct(filters=384, kernel_size=(5,), stride=(2,), + norm_type='ln', bias=True, dropout=0.1, + act_func='relu'), + Conv1dNormAct(filters=512, kernel_size=(5,), stride=(2,), + norm_type='ln', bias=True, dropout=0.1, + act_func='relu'), + Conv1dNormAct(filters=512, kernel_size=(1,), stride=(1,), + norm_type='ln', bias=True, + act_func=None), + ], + transformer_block=Wav2VecTransformerConfig( + use_pytorch_transformer=False, + dropout=0.1, + conv=ConvConfig( + conv_pos=128, + conv_pos_groups=16 + ), + encoder=Wav2VecTransformerEncoderConfig( + encoder_layers=2, + encoder_layerdrop=0.0, + embedding_dim=512, + ffn_embedding_dim=512 * 4, + num_attention_heads=8, + dropout=0.1, + activation_fn=Wav2VecActivationType.gelu, + layer_norm_first=True + ) + ), + ), + ConvTransformerBlock( + conv_layers=[Conv1dNormAct(filters=768 * 2, kernel_size=(5,), stride=(1,), + norm_type='ln', bias=True, dropout=0.1, + act_func='relu'), + Conv1dNormAct(filters=768, kernel_size=(1,), stride=(1,), + norm_type='ln', bias=True, + act_func=None), + ], + transformer_block=Wav2VecTransformerConfig( + use_pytorch_transformer=False, + dropout=0.1, + conv=ConvConfig( + conv_pos=128, + conv_pos_groups=16 + ), + encoder=Wav2VecTransformerEncoderConfig( + encoder_layers=10, + encoder_layerdrop=0.05, + embedding_dim=768, + ffn_embedding_dim=3072, + num_attention_heads=12, + dropout=0.1, + activation_fn=Wav2VecActivationType.gelu, + layer_norm_first=True + ) + ), + ), + ], + ), + masking=Wav2VecMaskingConfig( + mask_prob=0.5, + mask_type=Wav2VecMaskType.static, + mask_emb_type='gaussian', + mask_other=0, + mask_length=20, + no_mask_overlap=False, + mask_min_space=1, + mask_channel_prob=0.4, + mask_channel_type=Wav2VecMaskType.static, + mask_channel_other=0, + mask_channel_length=20, + no_mask_channel_overlap=False, + mask_channel_min_space=1, + mask_shrink_to_batch_min=False + ), + target_compute_perturb=True, + target_shifting=ShiftPerturbConfig( + dist='uniform', + shift_prob=1.0, + max_ratio=0.5, + unit=4, + max=16, + min=0, + truncate=False, + ), + target_momentum_type='cosine', + target_momentum=0.99, + target_momentum_final=0.999, + target_momentum_steps=max_steps, + projector=ProjectorConfig(output_dim=256), + predictor=ProjectorConfig( + conv_layers=[ + Conv1dNormAct(filters=256, kernel_size=(5,), stride=(1,), + norm_type='bn', + act_func='relu'), + Conv1dNormAct(filters=256, kernel_size=(5,), stride=(1,), + norm_type='bn', + act_func='relu'), + ], + output_dim=256 + ), + n_negatives=100, + cross_sample_negatives=0, + codebook_negatives=0, + negatives_from_everywhere=False, +) + +model = ST2VecPretrainModelConfig() + +model.st2vec_encoder = st2vec_encoder + +model.logit_temp = 0.3 +model.loss = LossConfig( + prob_ppl_weight=0.0 +) + + +model.train_ds = AudioDatasetConfig( + manifest_filepath='manifest_json/wenetspeech_train_2ms.json,manifest_json/librivox-train-clean-100.json,manifest_json/librivox-train-clean-360.json,manifest_json/librivox-train-other-500.json,manifest_json/librivox_13000h.json', + sample_rate=sample_rate, + batch_size=24, + min_duration=2.0, + crop_size=250000, + shuffle=True, + num_workers=4, + pin_memory=True, +) + +model.validation_ds = AudioDatasetConfig( + manifest_filepath='manifest_json/dev.json,manifest_json/librivox-dev-clean.json', + sample_rate=sample_rate, + batch_size=24, + min_duration=2.0, + crop_size=250000, + shuffle=False, + num_workers=4, +) + +model.test_ds = AudioDatasetConfig( + manifest_filepath='manifest_json/test_net.json,manifest_json/test_meeting.json,manifest_json/librivox-test-clean.json', + sample_rate=sample_rate, + batch_size=24, + min_duration=2.0, + crop_size=250000, + shuffle=False, + num_workers=4, +) + +model.expected_gpu_num = 16 +model.optim = AdamWParams( + lr=lr, + eps=1e-6, + betas=[0.9, 0.98], + weight_decay=0.01, + sched=CosineAnnealingParams( + min_lr=0.0, + warmup_steps=32000, + max_steps=max_steps, + ), +) + +trainer = TrainerConfig( + gpus=8, + #max_epochs=280, + max_epochs=140, + accelerator="ddp", + accumulate_grad_batches=1, + checkpoint_callback=False, # Provided by exp_manager + logger=False, # Provided by exp_manager + log_every_n_steps=50, + progress_bar_refresh_rate=50, + num_sanity_val_steps=0, + check_val_every_n_epoch=1, + num_nodes=1, + # + precision=16, + amp_backend="apex", + amp_level='O1', +) + +exp_manager = ExpManagerConfig( + name=config_name, + create_checkpoint_callback=True, + checkpoint_callback_params=CallbackParams( + save_top_k=5 + ) +) + +cfg = ModelPTConfig( + name=config_name, + model=model, + trainer=trainer, + exp_manager=exp_manager +) diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/my_conf/st2vec_multi_lingual_wenetspeech_13k960_pretrained_aishell1_ls100_finetune_FSQ_8888_CTC_4ConvDec_phone_40ms_20240522.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/my_conf/st2vec_multi_lingual_wenetspeech_13k960_pretrained_aishell1_ls100_finetune_FSQ_8888_CTC_4ConvDec_phone_40ms_20240522.py new file mode 100644 index 0000000000000000000000000000000000000000..18e3a488039ff31722bd6e02709f07672fb1e974 --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/my_conf/st2vec_multi_lingual_wenetspeech_13k960_pretrained_aishell1_ls100_finetune_FSQ_8888_CTC_4ConvDec_phone_40ms_20240522.py @@ -0,0 +1,192 @@ +# Copyright (C) 2022. Huawei Technologies Co., Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nemo.collections.asr.models.configs.common_config import AdamWParams, DatasetConfig, Conv1dNormAct, \ + PolynomialHoldDecayAnnealingParams, ProjUpsampling, Tokenizer +from nemo.collections.asr.models.configs.ctc_models_config import ConvASRDecoderConfig +from nemo.collections.asr.models.wav2vec.wav2vec_config import QuantizerConfig +from nemo.core.config import TrainerConfig +from nemo.core.config.modelPT import ModelPTConfig +from nemo.utils.exp_manager import ExpManagerConfig, CallbackParams +from nemo.collections.asr.models.spec2vec.spec2vec_config import ST2VecVQCTCFinetuneModelConfig + +config_name = 'vq_ctc_finetune' + +sample_rate = 16000 +num_features = 128 + +model = ST2VecVQCTCFinetuneModelConfig() + +# # English labels +# LABELS = ['AA0', 'AA1', 'AA2', 'AE0', 'AE1', 'AE2', 'AH0', 'AH1', 'AH2', 'AO0', 'AO1', 'AO2', 'AW0', 'AW1', 'AW2', +# 'AY0', 'AY1', 'AY2', 'B', 'CH', 'D', 'DH', 'EH0', 'EH1', 'EH2', 'ER0', 'ER1', 'ER2', 'EY0', 'EY1', 'EY2', 'F', +# 'G', 'HH', 'IH0', 'IH1', 'IH2', 'IY0', 'IY1', 'IY2', 'JH', 'K', 'L', 'M', 'N', 'NG', 'OW0', 'OW1', 'OW2', +# 'OY0', 'OY1', 'OY2', 'P', 'R', 'S', 'SH', 'T', 'TH', 'UH0', 'UH1', 'UH2', 'UW0', 'UW1', 'UW2', 'V', 'W', 'Y', +# 'Z', 'ZH'] +# # tonal phone number: 69 + +# # Chinese labels +# LABELS = ['a1', 'a2', 'a3', 'a4', 'a5', 'ai1', 'ai2', 'ai3', 'ai4', 'an1', 'an2', 'an3', 'an4', 'ang1', 'ang2', 'ang3', +# 'ang4', 'ao1', 'ao2', 'ao3', 'ao4', 'b', 'c', 'ch', 'd', 'e1', 'e2', 'e3', 'e4', 'e5', 'ei1', 'ei2', 'ei3', +# 'ei4', 'en1', 'en2', 'en3', 'en4', 'en5', 'eng1', 'eng2', 'eng3', 'eng4', 'f', 'g', 'h', 'i1', 'i2', 'i3', +# 'i4', 'ia1', 'ia2', 'ia3', 'ia4', 'ian1', 'ian2', 'ian3', 'ian4', 'iang1', 'iang2', 'iang3', 'iang4', 'iao1', +# 'iao2', 'iao3', 'iao4', 'ie1', 'ie2', 'ie3', 'ie4', 'ii1', 'ii2', 'ii3', 'ii4', 'ii5', 'iii1', 'iii2', 'iii3', +# 'iii4', 'in1', 'in2', 'in3', 'in4', 'ing1', 'ing2', 'ing3', 'ing4', 'iong1', 'iong2', 'iong3', 'iong4', +# 'iou1', 'iou2', 'iou3', 'iou4', 'j', 'k', 'l', 'm', 'n', 'o1', 'o2', 'o3', 'o4', 'ong1', 'ong2', 'ong3', +# 'ong4', 'ou1', 'ou2', 'ou3', 'ou4', 'p', 'q', 'r', 'rr', 's', 'sh', 't', 'u1', 'u2', 'u3', 'u4', 'ua1', 'ua2', +# 'ua3', 'ua4', 'uai1', 'uai2', 'uai3', 'uai4', 'uan1', 'uan2', 'uan3', 'uan4', 'uang1', 'uang2', 'uang3', +# 'uang4', 'ui1', 'ui2', 'ui3', 'ui4', 'un1', 'un2', 'un3', 'un4', 'uo1', 'uo2', 'uo3', 'uo4', 'uo5', 'v1', +# 'v2', 'v3', 'v4', 'van1', 'van2', 'van3', 'van4', 've1', 've2', 've3', 've4', 'vn1', 'vn2', 'vn3', 'vn4', 'x', +# 'z', 'zh'] +# # tonal phone number: 171 + +LABELS = ['en_AA0', 'en_AA1', 'en_AA2', 'en_AE0', 'en_AE1', 'en_AE2', 'en_AH0', 'en_AH1', 'en_AH2', 'en_AO0', 'en_AO1', + 'en_AO2', 'en_AW0', 'en_AW1', 'en_AW2', 'en_AY0', 'en_AY1', 'en_AY2', 'en_B', 'en_CH', 'en_D', 'en_DH', + 'en_EH0', + 'en_EH1', 'en_EH2', 'en_ER0', 'en_ER1', 'en_ER2', 'en_EY0', 'en_EY1', 'en_EY2', 'en_F', 'en_G', 'en_HH', + 'en_IH0', + 'en_IH1', 'en_IH2', 'en_IY0', 'en_IY1', 'en_IY2', 'en_JH', 'en_K', 'en_L', 'en_M', 'en_N', 'en_NG', 'en_OW0', + 'en_OW1', 'en_OW2', 'en_OY0', 'en_OY1', 'en_OY2', 'en_P', 'en_R', 'en_S', 'en_SH', 'en_T', 'en_TH', 'en_UH0', + 'en_UH1', 'en_UH2', 'en_UW0', 'en_UW1', 'en_UW2', 'en_V', 'en_W', 'en_Y', 'en_Z', 'en_ZH', + 'CN_a1', 'CN_a2', 'CN_a3', 'CN_a4', 'CN_a5', 'CN_ai1', 'CN_ai2', 'CN_ai3', 'CN_ai4', 'CN_an1', 'CN_an2', + 'CN_an3', + 'CN_an4', 'CN_ang1', 'CN_ang2', 'CN_ang3', 'CN_ang4', 'CN_ao1', 'CN_ao2', 'CN_ao3', 'CN_ao4', 'CN_b', 'CN_c', + 'CN_ch', 'CN_d', 'CN_e1', 'CN_e2', 'CN_e3', 'CN_e4', 'CN_e5', 'CN_ei1', 'CN_ei2', 'CN_ei3', 'CN_ei4', + 'CN_en1', + 'CN_en2', 'CN_en3', 'CN_en4', 'CN_en5', 'CN_eng1', 'CN_eng2', 'CN_eng3', 'CN_eng4', 'CN_f', 'CN_g', 'CN_h', + 'CN_i1', 'CN_i2', 'CN_i3', 'CN_i4', 'CN_ia1', 'CN_ia2', 'CN_ia3', 'CN_ia4', 'CN_ian1', 'CN_ian2', 'CN_ian3', + 'CN_ian4', 'CN_iang1', 'CN_iang2', 'CN_iang3', 'CN_iang4', 'CN_iao1', 'CN_iao2', 'CN_iao3', 'CN_iao4', + 'CN_ie1', + 'CN_ie2', 'CN_ie3', 'CN_ie4', 'CN_ii1', 'CN_ii2', 'CN_ii3', 'CN_ii4', 'CN_ii5', 'CN_iii1', 'CN_iii2', + 'CN_iii3', + 'CN_iii4', 'CN_in1', 'CN_in2', 'CN_in3', 'CN_in4', 'CN_ing1', 'CN_ing2', 'CN_ing3', 'CN_ing4', 'CN_iong1', + 'CN_iong2', 'CN_iong3', 'CN_iong4', 'CN_iou1', 'CN_iou2', 'CN_iou3', 'CN_iou4', 'CN_j', 'CN_k', 'CN_l', + 'CN_m', + 'CN_n', 'CN_o1', 'CN_o2', 'CN_o3', 'CN_o4', 'CN_ong1', 'CN_ong2', 'CN_ong3', 'CN_ong4', 'CN_ou1', 'CN_ou2', + 'CN_ou3', 'CN_ou4', 'CN_p', 'CN_q', 'CN_r', 'CN_rr', 'CN_s', 'CN_sh', 'CN_t', 'CN_u1', 'CN_u2', 'CN_u3', + 'CN_u4', + 'CN_ua1', 'CN_ua2', 'CN_ua3', 'CN_ua4', 'CN_uai1', 'CN_uai2', 'CN_uai3', 'CN_uai4', 'CN_uan1', 'CN_uan2', + 'CN_uan3', + 'CN_uan4', 'CN_uang1', 'CN_uang2', 'CN_uang3', 'CN_uang4', 'CN_ui1', 'CN_ui2', 'CN_ui3', 'CN_ui4', 'CN_un1', + 'CN_un2', 'CN_un3', 'CN_un4', 'CN_uo1', 'CN_uo2', 'CN_uo3', 'CN_uo4', 'CN_uo5', 'CN_v1', 'CN_v2', 'CN_v3', + 'CN_v4', + 'CN_van1', 'CN_van2', 'CN_van3', 'CN_van4', 'CN_ve1', 'CN_ve2', 'CN_ve3', 'CN_ve4', 'CN_vn1', 'CN_vn2', + 'CN_vn3', + 'CN_vn4', 'CN_x', 'CN_z', 'CN_zh'] +# tonal phone number: 69 + 171 = 240 + + +model.labels = LABELS +model.label_type = 'phone' # one of ['char', 'phone','bpe'] +model.add_end_space = False +model.tokenizer = None # if tokenizer is not None, use BPE + +from my_conf.st2vec_lfr_pretrain_maskp5cp4gaus_tp3_tgtshift16_preln_lr3e3_40ms_fp16_init80ms_multilingual_2 import \ + st2vec_encoder + +encoder = st2vec_encoder +encoder.masking.mask_prob = 0 # actually no mask +encoder.masking.mask_channel_prob = 0 + +transformer0 = encoder.feature_encoder.conv_transformer_blocks[-2].transformer_block +transformer0.encoder.activation_dropout = 0.1 +transformer0.encoder.dropout = 0.1 +transformer = encoder.feature_encoder.conv_transformer_blocks[-1].transformer_block +transformer.encoder.encoder_layerdrop = 0.1 +transformer.encoder.activation_dropout = 0.1 +transformer.encoder.dropout = 0.1 + +model.encoder = encoder + +# encoder->quantizer->decoder->CTC + +# For finite scalar quantizer +# Quantization level of each dimension +quant_level_per_dim = [8, 8, 8, 8] +model.quantizer = QuantizerConfig( + levels=quant_level_per_dim, + l2_norm=False, + batch_norm=False, +) + +# enc_output_dim = transformer.encoder.embedding_dim # 768 + +# feat_in=quantizer_latent_dim, +model.decoder = ConvASRDecoderConfig( + feat_in=len(quant_level_per_dim), + # proj_upsampling=ProjUpsampling(rate=1, filters=512, kernel_size=(5,), norm_type='ln', act_func='relu', dropout=0.1), + conv_layers=[Conv1dNormAct(filters=512, kernel_size=(5,), stride=(1,), + norm_type=None, dropout=0.1, + act_func='relu'), + Conv1dNormAct(filters=512, kernel_size=(5,), stride=(1,), + norm_type=None, dropout=0.1, + act_func='relu'), + Conv1dNormAct(filters=512, kernel_size=(5,), stride=(1,), + norm_type=None, dropout=0.1, + act_func='relu'), + Conv1dNormAct(filters=512, kernel_size=(5,), stride=(1,), + norm_type=None, dropout=0.1, + act_func='relu'), + ], + vocabulary=LABELS, + blank_pos='after_vocab_last' +) + +model.quant_ppl_loss_weight = 1 +_batch_size = 14 # 4 + +model.expected_gpu_num = 8 +lr = 0.00003 +model.optim = AdamWParams( + lr=lr, + eps=1e-6, + betas=[0.9, 0.98], + weight_decay=0.01, + sched=PolynomialHoldDecayAnnealingParams( + min_lr=lr * 0.05, + warmup_ratio=0.1, + hold_ratio=0.4, + max_steps=80000, + ), +) + +trainer = TrainerConfig( + gpus=8, + max_epochs=320, + accelerator='ddp', + accumulate_grad_batches=1, + checkpoint_callback=False, # Provided by exp_manager + logger=False, # Provided by exp_manager + log_every_n_steps=50, + progress_bar_refresh_rate=50, + num_sanity_val_steps=0, + check_val_every_n_epoch=1 +) + +exp_manager = ExpManagerConfig( + name=config_name, + create_checkpoint_callback=True, + checkpoint_callback_params=CallbackParams( + monitor="val_wer", + mode="min", + save_top_k=5 + ) +) + +cfg = ModelPTConfig( + name=config_name, + model=model, + trainer=trainer, + exp_manager=exp_manager +) diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/my_extract_unit_for_speech/extract_unit_construct_wav_unit_text.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/my_extract_unit_for_speech/extract_unit_construct_wav_unit_text.py new file mode 100644 index 0000000000000000000000000000000000000000..8f9927c0c1ec32659b60ddb5fb497db786290e6d --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/my_extract_unit_for_speech/extract_unit_construct_wav_unit_text.py @@ -0,0 +1,132 @@ +from pathlib import Path +from tqdm import tqdm +import torch +import numpy as np +import librosa + +try: + import torch_npu + from torch_npu.npu import amp + from torch_npu.contrib import transfer_to_npu + print('Successful import torch_npu') +except Exception as e: + print(e) + +default_sample_rate = 16000 + +def read_audio(path): + # SPIRAL model can only consume audio with sample_rate of 16000, resample may be carried out + wav, origin_sample_rate = librosa.load(path, sr=None) + if origin_sample_rate != default_sample_rate: + wav = librosa.resample(wav, orig_sr=origin_sample_rate, target_sr=default_sample_rate) + assert wav.ndim == 1, wav.ndim + return wav + +def extract_feature(model, audio, audio_length, enable_vq=True): + encoded, encoded_len = model.encoder(audio, audio_length, None, None, mask=False, features_only=True) # [B, T, D] + if enable_vq: + # For finte sclar quantizer + encoded = model.pre_quant(encoded) + q_feat = model.quantizer(encoded) + q_feat_ids = model.quantizer.codes_to_indexes(q_feat) + q_feat_ids = q_feat_ids.to(torch.int) + return q_feat_ids, encoded_len + return encoded, encoded_len + +# reduced_unit_sequence surpasses duplicate patterns +def sample_extract_unit(wav_path, model, show_result=False): + # load audio + audio = read_audio(wav_path) + x = torch.from_numpy(audio) + max_chunk = 1600000 + x_len = [x.shape[0]] + device = next(model.parameters()).device # we assume speech tokenizer is stored in a single device + with torch.no_grad(): + x = x.float().to(device) + x = x.view(1, -1) + x_len = torch.from_numpy(np.array(x_len)).int().to(device) + feat = [] + for start in range(0, x.size(1), max_chunk): + x_chunk = x[:, start: start + max_chunk] + feat_chunk, _ = extract_feature(model, x_chunk, x_len) + feat.append(feat_chunk) + # return T x C + unit_list = torch.cat(feat, 1).squeeze(0).cpu().numpy().tolist() + reduced_unit_list = [] + prev_unit = None + for unit in unit_list: + if unit != prev_unit: + reduced_unit_list.append(unit) + prev_unit = unit + unit_sequence = ' '.join([str(x) for x in unit_list]) + reduced_unit_sequence = ' '.join([str(x) for x in reduced_unit_list]) + if show_result: + print('unit_sequence:') + print(unit_sequence) + print('reduced_unit_sequence:') + print(reduced_unit_sequence) + return unit_sequence, reduced_unit_sequence + +def batch_extract_unit(wav_file_list, model, show_result=False, max_chunk = 480000 ): + # in this setting, we dont chunck the wav, but to process by a batch + # max_chunk = 480000 limit to 30 s to avoid out-of-memory issue + # load audio + audio_list = [] + audio_len_list = [] + skip_audio_num = 0 + extracted_wav_file_list = [] + skipped_wav_file_list = [] + for wav_file in wav_file_list: + audio = read_audio(wav_file) + audio_len = audio.shape[0] + if audio_len > max_chunk: + print( + f'x is too long, x_len {audio_len} is longer than max_chunk {max_chunk}, skip it and extract later') + skip_audio_num += 1 + skipped_wav_file_list.append(wav_file) + continue # not to stop the extraction + extracted_wav_file_list.append(wav_file) + audio_list.append(audio) + audio_len_list.append(audio_len) + actual_batch_size = len(extracted_wav_file_list) # Actual batch size after removal of too-long audio + if actual_batch_size == 0: + return [], skipped_wav_file_list, [], [] + + device = next(model.parameters()).device # we assume speech tokenizer is stored in a single device + max_len = max(audio_len_list) + batch_audio = torch.zeros([actual_batch_size, max_len]).float().to(device) + for i in range(actual_batch_size): + batch_audio[i, :audio_len_list[i]] = torch.from_numpy(audio_list[i]) # stack audio of different length + batch_audio_len = torch.from_numpy(np.array(audio_len_list)).int().to(device) + with torch.no_grad(): + batch_feat, batch_feat_len = extract_feature(model, batch_audio, batch_audio_len) + # return B x T x C + unit_sequence_list, reduced_unit_sequence_list = [], [] + for i in range(len(batch_feat_len)): + feat_len = batch_feat_len[i] + feat = batch_feat[i][:feat_len] # Todo: should check! + unit_list = feat.cpu().numpy().tolist() # Todo: squeeze 0 should be removed, because we have a batch + reduced_unit_list = [] + prev_unit = None + for unit in unit_list: + if unit != prev_unit: + reduced_unit_list.append(unit) + prev_unit = unit + unit_sequence = ' '.join([str(x) for x in unit_list]) + reduced_unit_sequence = ' '.join([str(x) for x in reduced_unit_list]) + unit_sequence_list.append(unit_sequence) + reduced_unit_sequence_list.append(reduced_unit_sequence) + if show_result: + print('unit_sequence:') + print(unit_sequence) + print('reduced_unit_sequence:') + print(reduced_unit_sequence) + return extracted_wav_file_list, skipped_wav_file_list, unit_sequence_list, reduced_unit_sequence_list + +def get_S2U_ckpt_config_path(unit_type, language='English'): + assert language in ['English', 'Chinese'] + assert unit_type == '40ms_multilingual_8888' + # English and Chinese using the same SPIRAL model and config!! + ckpt_path = "./speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/my_S2U_model/SPIRAL2_base_mutilingual_wenet_lv13k960_pretrain_aishell1_ls100_finetune_FSQ_8888_CTC_4ConvDec_phone_40ms_1n8g_20240522/checkpoints/vq_ctc_finetune--val_wer=0.0293-epoch=316.ckpt" + config_path = "my_conf/st2vec_multi_lingual_wenetspeech_13k960_pretrained_aishell1_ls100_finetune_FSQ_8888_CTC_4ConvDec_phone_40ms_20240522" + return ckpt_path, config_path \ No newline at end of file diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/my_scripts/fsq.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/my_scripts/fsq.py new file mode 100644 index 0000000000000000000000000000000000000000..88a13c9baa158146071a1cb9c0212d114c666e81 --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/my_scripts/fsq.py @@ -0,0 +1,116 @@ +# --------------------------------------------------------------------------------- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# --------------------------------------------------------------------------------- + +# This script is modified from https://github.com/google-research/google-research/blob/master/fsq/fsq.ipynb +# by converting jax to torch. + +import numpy as np +import torch +import torch.nn as nn +from nemo.core import NeuralModule + +Codeword = torch.FloatTensor +Indices = torch.FloatTensor + + +def round_ste(z): + """Round with straight through gradients.""" + zhat = torch.round(z) + return z + (zhat - z).detach() + + +class FSQ(NeuralModule): + """Quantizer.""" + + def __init__(self, levels: list, eps: float = 1e-3, l2_norm: bool = False, batch_norm: bool = False): + super().__init__() + + self._levels = levels + self._eps = eps + self.l2_norm = l2_norm + self.batch_norm = batch_norm + # self._levels_np = torch.Tensor(levels) + # self._basis = torch.cat((torch.Tensor([1]), torch.cumprod(self._levels_np[:-1], dim=0))) + self.register_buffer("_levels_np", torch.Tensor(levels)) + self.register_buffer("_basis", torch.cat((torch.Tensor([1]), torch.cumprod(self._levels_np[:-1], dim=0)))) + + self._implicit_codebook = self.indexes_to_codes(torch.arange(self.codebook_size)) + + if self.batch_norm: + self.bn = nn.BatchNorm1d(self.num_dimensions, momentum=0.01, eps=1e-3) + + @property + def num_dimensions(self) -> int: + """Number of dimensions expected from inputs.""" + return len(self._levels) + + @property + def codebook_size(self) -> int: + """Size of the codebook.""" + return np.prod(self._levels) + + @property + def codebook(self): + """Returns the implicit codebook. Shape (prod(levels), num_dimensions).""" + return self._implicit_codebook + + def bound(self, z: torch.FloatTensor) -> torch.FloatTensor: + """Bound `z`, an array of shape (..., d).""" + half_l = (self._levels_np - 1) * (1 - self._eps) / 2 + offset = torch.where(self._levels_np % 2 == 1, 0.0, 0.5) + shift = torch.tan(offset / half_l) + return torch.tanh(z + shift) * half_l - offset + + def quantize(self, z: torch.FloatTensor) -> Codeword: + """Quanitzes z, returns quantized zhat, same shape as z.""" + quantized = round_ste(self.bound(z)) + + # Renormalize to [-1, 1]. + half_width = torch.div(self._levels_np, 2, rounding_mode='floor') + return quantized / half_width + + def _scale_and_shift(self, zhat_normalized): + # Scale and shift to range [0, ..., L-1] + half_width = torch.div(self._levels_np, 2, rounding_mode='floor') + return (zhat_normalized * half_width) + half_width + + def _scale_and_shift_inverse(self, zhat): + # Note that array(x) // 2 != tensor(x) // 2 when x is negative + half_width = torch.div(self._levels_np, 2, rounding_mode='floor') + return (zhat - half_width) / half_width + + def codes_to_indexes(self, zhat: Codeword) -> Indices: + """Converts a `code` to an index in the codebook.""" + assert zhat.shape[-1] == self.num_dimensions + zhat = self._scale_and_shift(zhat) + return torch.sum(zhat * self._basis, axis=-1) + + def indexes_to_codes(self, indices: Indices) -> Codeword: + """Inverse of `indexes_to_codes`.""" + indices = indices.unsqueeze(-1) + codes_non_centered = torch.remainder( + torch.div(indices, self._basis, rounding_mode='floor'), self._levels_np + ) + return self._scale_and_shift_inverse(codes_non_centered) + + def forward(self, z: torch.FloatTensor) -> Codeword: + # z.shape: [batch_size, seq_len, feat_size] + if self.l2_norm: + z = nn.functional.normalize(z, p=2, dim=-1) + + zhat = self.quantize(z) + + return zhat \ No newline at end of file diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/README.md b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/README.md new file mode 100644 index 0000000000000000000000000000000000000000..2db456547ab9ba648063be0711247a98aecde594 --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/README.md @@ -0,0 +1,9 @@ +NeMo (**Ne**ural **Mo**dules) is a toolkit for creating AI applications built around **neural modules**, conceptual blocks of neural networks that take *typed* inputs and produce *typed* outputs. + +**NeMo Core** provides common APIs all modules and models have to implement. + +**NeMo Collections** + +* ASR - collection of modules and models for building speech recognition networks +* TTS - collection of modules and models for building speech synthesis networks +* NLP - collection of modules and models for building NLP networks diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/__init__.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..70b372752d04027463950617947aa3cf815fa2fb --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/__init__.py @@ -0,0 +1,33 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import os + +from .package_info import ( + __contact_emails__, + __contact_names__, + __description__, + __download_url__, + __homepage__, + __keywords__, + __license__, + __package_name__, + __repository_url__, + __shortversion__, + __version__, +) + +if "NEMO_PACKAGE_BUILDING" not in os.environ: + from nemo import collections, core, utils diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/__init__.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9e3250071955216f6abc505e6181fb59931baa8d --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/__init__.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cd14a4395006b23564424aa7d2f90a1069d5894f --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/__init__.py @@ -0,0 +1,25 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nemo.collections.asr import data, losses, models, modules +from nemo.package_info import __version__ + +# Set collection version equal to NeMo version. +__version = __version__ + +# Authorship. +__author__ = "NVIDIA Corporation" + +# Set collection name. +__description__ = "Automatic Speech Recognition collection" diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/losses/__init__.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/losses/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d5d8200ead58ae4586688a1c6478c054aed52d0f --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/losses/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nemo.collections.asr.losses.angularloss import AngularSoftmaxLoss +from nemo.collections.asr.losses.ctc import CTCLoss diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/losses/angularloss.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/losses/angularloss.py new file mode 100644 index 0000000000000000000000000000000000000000..e2aee9bba6eaa391573acb70be1bfcb53e2940bd --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/losses/angularloss.py @@ -0,0 +1,68 @@ +# ! /usr/bin/python +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from nemo.core.classes import Loss, Typing, typecheck +from nemo.core.neural_types import LabelsType, LogitsType, LossType, NeuralType + +__all__ = ['AngularSoftmaxLoss'] + + +class AngularSoftmaxLoss(Loss, Typing): + """ + Computes ArcFace Angular softmax angle loss + reference: https://openaccess.thecvf.com/content_CVPR_2019/papers/Deng_ArcFace_Additive_Angular_Margin_Loss_for_Deep_Face_Recognition_CVPR_2019_paper.pdf + args: + scale: scale value for cosine angle + margin: margin value added to cosine angle + """ + + @property + def input_types(self): + """Input types definitions for AnguarLoss. + """ + return { + "logits": NeuralType(('B', 'D'), LogitsType()), + "labels": NeuralType(('B',), LabelsType()), + } + + @property + def output_types(self): + """Output types definitions for AngularLoss. + loss: + NeuralType(None) + """ + return {"loss": NeuralType(elements_type=LossType())} + + def __init__(self, scale=20.0, margin=1.35): + super().__init__() + + self.eps = 1e-7 + self.scale = scale + self.margin = margin + + @typecheck() + def forward(self, logits, labels): + numerator = self.scale * torch.cos( + torch.acos(torch.clamp(torch.diagonal(logits.transpose(0, 1)[labels]), -1.0 + self.eps, 1 - self.eps)) + + self.margin + ) + excl = torch.cat( + [torch.cat((logits[i, :y], logits[i, y + 1 :])).unsqueeze(0) for i, y in enumerate(labels)], dim=0 + ) + denominator = torch.exp(numerator) + torch.sum(torch.exp(self.scale * excl), dim=1) + L = numerator - torch.log(denominator) + return -torch.mean(L) diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/losses/ctc.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/losses/ctc.py new file mode 100644 index 0000000000000000000000000000000000000000..eaec18a71378c37cdd6577363240cdf12c3a1122 --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/losses/ctc.py @@ -0,0 +1,127 @@ +# ! /usr/bin/python +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from torch import nn + +from nemo.core.classes import Serialization, Typing, typecheck +from nemo.core.neural_types import LabelsType, LengthsType, LogprobsType, LossType, NeuralType +from nemo.utils.decorators import experimental + +__all__ = ['CTCLoss'] + + +@experimental +class CTCLoss(nn.CTCLoss, Serialization, Typing): + @property + def input_types(self): + """Input types definitions for CTCLoss. + """ + return { + "log_probs": NeuralType(('B', 'T', 'D'), LogprobsType()), + "targets": NeuralType(('B', 'T'), LabelsType()), + "input_lengths": NeuralType(tuple('B'), LengthsType()), + "target_lengths": NeuralType(tuple('B'), LengthsType()), + } + + @property + def output_types(self): + """Output types definitions for CTCLoss. + loss: + NeuralType(None) + """ + return {"loss": NeuralType(elements_type=LossType())} + + def __init__(self, blank_id, zero_infinity=False, reduction='mean_batch'): + self._blank = blank_id + # Don't forget to properly call base constructor + if reduction == 'mean_batch': + ctc_reduction = 'none' + self._apply_batch_mean = True + elif reduction in ['sum', 'mean', 'none']: + ctc_reduction = reduction + self._apply_batch_mean = False + super().__init__(blank=self._blank, reduction=ctc_reduction, zero_infinity=zero_infinity) + + @typecheck() + def forward(self, log_probs, targets, input_lengths, target_lengths): + # override forward implementation + # custom logic, if necessary + input_lengths = input_lengths.long() + target_lengths = target_lengths.long() + targets = targets.long() + # here we transpose because we expect [B, T, D] while PyTorch assumes [T, B, D] + log_probs = log_probs.transpose(1, 0) + loss = super().forward( + log_probs=log_probs, targets=targets, input_lengths=input_lengths, target_lengths=target_lengths + ) + if self._apply_batch_mean: + loss = torch.mean(loss) + return loss + + +# Below is how "custom" loss should work +# @experimental +# class CTCLoss(Loss): +# """ +# CTCLoss +# Args: +# num_classes (int): Number of characters in ASR model's vocab/labels. +# This count should not include the CTC blank symbol. +# zero_infinity (bool): Whether to zero infinite losses and the associated gradients. +# By default, it is False. Infinite losses mainly occur when the inputs are too +# short to be aligned to the targets. +# """ +# +# def save_to(self, save_path: str): +# pass +# +# @classmethod +# def restore_from(cls, restore_path: str): +# pass +# +# @property +# def input_types(self): +# """Input types definitions for CTCLoss. +# """ +# return { +# "log_probs": NeuralType(('B', 'T', 'D'), LogprobsType()), +# "targets": NeuralType(('B', 'T'), LabelsType()), +# "input_length": NeuralType(tuple('B'), LengthsType()), +# "target_length": NeuralType(tuple('B'), LengthsType()), +# } +# +# @property +# def output_types(self): +# """Output types definitions for CTCLoss. +# loss: +# NeuralType(None) +# """ +# return {"loss": NeuralType(elements_type=LossType())} +# +# def __init__(self, num_classes, zero_infinity=False): +# super().__init__() +# +# self._blank = num_classes +# self._criterion = nn.CTCLoss(blank=self._blank, reduction='none', zero_infinity=zero_infinity) +# +# @typecheck() +# def forward(self, log_probs, targets, input_length, target_length): +# input_length = input_length.long() +# target_length = target_length.long() +# targets = targets.long() +# loss = self._criterion(log_probs.transpose(1, 0), targets, input_length, target_length) +# loss = torch.mean(loss) +# return loss diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/losses/rnnt.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/losses/rnnt.py new file mode 100644 index 0000000000000000000000000000000000000000..db444b85029491d21b1217e0f92d8199e287a659 --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/losses/rnnt.py @@ -0,0 +1,306 @@ +# ! /usr/bin/python +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright 2018-2019, Mingkun Huang +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import operator +from dataclasses import dataclass +from typing import Optional + +import torch +from omegaconf import DictConfig, OmegaConf + +from nemo.core.classes import Loss, typecheck +from nemo.core.neural_types import LabelsType, LengthsType, LogprobsType, LossType, NeuralType +from nemo.core.utils.numba_utils import NUMBA_INSTALLATION_MESSAGE +from nemo.utils import logging, model_utils + +try: + import warprnnt_pytorch as warprnnt + + WARP_RNNT_AVAILABLE = True +except (ImportError, ModuleNotFoundError): + WARP_RNNT_AVAILABLE = False + +try: + from nemo.collections.asr.parts.numba.rnnt_loss import RNNTLossNumba + + NUMBA_RNNT_AVAILABLE = True +except (ImportError, ModuleNotFoundError): + NUMBA_RNNT_AVAILABLE = False + + +WARP_RNNT_INSTALLATION_MESSAGE = ( + "Could not import `warprnnt_pytorch`.\n" + "Please visit https://github.com/HawkAaron/warp-transducer " + "and follow the steps in the readme to build and install the " + "pytorch bindings for RNNT Loss, or use the provided docker " + "container that supports RNN-T loss." +) + + +@dataclass +class RNNTLossConfig: + loss_name: str + lib_name: str + is_available: bool = False + installation_msg: str = "" + min_version: Optional[str] = None + + +# Resolved list of available RNNT losses +RNNT_LOSS_RESOLVER = { + "warprnnt": RNNTLossConfig( + loss_name="warprnnt", + lib_name="warprnnt_pytorch", + is_available=WARP_RNNT_AVAILABLE, + installation_msg=WARP_RNNT_INSTALLATION_MESSAGE, + ), + "warprnnt_numba": RNNTLossConfig( + loss_name="warprnnt_numba", + lib_name="numba", + min_version='0.53.0', + is_available=NUMBA_RNNT_AVAILABLE, + installation_msg=NUMBA_INSTALLATION_MESSAGE, + ), +} + +RNNT_LOSS_RESOLVER['default'] = RNNT_LOSS_RESOLVER['warprnnt'] + + +def _warn_unused_additional_kwargs(loss_name, kwargs): + if len(kwargs) > 0: + logging.warning( + f"Loss function `{loss_name}` was provided with following additional kwargs,\n" + f"however they were ignored as it is unused.\n" + f"{kwargs}" + ) + + +def resolve_rnnt_default_loss_name() -> str: + return RNNT_LOSS_RESOLVER['default'].loss_name + + +def resolve_rnnt_loss(loss_name: str, blank_idx: int, loss_kwargs: dict = None) -> torch.nn.Module: + loss_function_names = list(RNNT_LOSS_RESOLVER.keys()) + + if loss_name not in loss_function_names: + raise ValueError( + f"Provided `loss_name` {loss_name} not in list of available RNNT losses \n" f"{loss_function_names}" + ) + + all_available_losses = {name: config for name, config in RNNT_LOSS_RESOLVER.items() if config.is_available} + + loss_config = RNNT_LOSS_RESOLVER[loss_name] # type: RNNTLossConfig + + # Re-raise import error with installation message + if not loss_config.is_available: + msg = ( + f"Installed RNNT losses are : {list(all_available_losses.keys())}.\n" + f"****************************************************************\n" + f"To install the selected loss function, please follow the steps below:\n" + f"{loss_config.installation_msg}" + ) + raise ImportError(msg) + + # Library version check + if loss_config.min_version is not None: + ver_matched, msg = model_utils.check_lib_version( + loss_config.lib_name, checked_version=loss_config.min_version, operator=operator.ge + ) + + if ver_matched is False: + msg = ( + f"{msg}\n" + f"****************************************************************\n" + f"To update the selected loss function, please follow the steps below:\n" + f"{loss_config.installation_msg}" + ) + raise RuntimeError(msg) + + # Resolve loss functions sequentially + loss_kwargs = {} if loss_kwargs is None else loss_kwargs + + if isinstance(loss_kwargs, DictConfig): + loss_kwargs = OmegaConf.to_container(loss_kwargs, resolve=True) + + # Get actual loss name for `default` + if loss_name == 'default': + loss_name = loss_config.loss_name + + """ + Resolve RNNT loss functions + """ + if loss_name == 'warprnnt': + loss_func = warprnnt.RNNTLoss(blank=blank_idx, reduction='none') + _warn_unused_additional_kwargs(loss_name, loss_kwargs) + + elif loss_name == 'warprnnt_numba': + fastemit_lambda = loss_kwargs.pop('fastemit_lambda', 0.0) + loss_func = RNNTLossNumba(blank=blank_idx, reduction='none', fastemit_lambda=fastemit_lambda) + _warn_unused_additional_kwargs(loss_name, loss_kwargs) + + else: + raise ValueError( + f"Invalid value of `loss_name`: {loss_name}. Allowed loss names are :" f"{loss_function_names}" + ) + + return loss_func + + +class RNNTLoss(Loss): + @property + def input_types(self): + """Input types definitions for CTCLoss. + """ + return { + "log_probs": NeuralType(('B', 'T', 'T', 'D'), LogprobsType()), + "targets": NeuralType(('B', 'T'), LabelsType()), + "input_lengths": NeuralType(tuple('B'), LengthsType()), + "target_lengths": NeuralType(tuple('B'), LengthsType()), + } + + @property + def output_types(self): + """Output types definitions for CTCLoss. + loss: + NeuralType(None) + """ + return {"loss": NeuralType(elements_type=LossType())} + + def __init__(self, blank_idx, reduction: str = 'mean_batch', loss_name: str = "default", loss_kwargs=None): + """ + RNN-T Loss function based on https://github.com/HawkAaron/warp-transducer. + Optionally, can utilize a numba implementation of the same loss without having to compile the loss, + albiet there is a small speed penalty for JIT numba compile. + + Note: + Requires Numba 0.53.0 or later to be installed to use this loss function. + + Losses can be selected via the config, and optionally be passed keyword arguments as follows. + + Examples: + .. code-block:: yaml + + model: # RNNT Model config + ... + loss: + loss_name: "warprnnt_numba" + warprnnt_numba_kwargs: + fastemit_lambda: 0.0 + + Warning: + In the case that GPU memory is exhausted in order to compute RNNTLoss, it might cause + a core dump at the cuda level with the following error message. + + ``` + ... + costs = costs.to(acts.device) + RuntimeError: CUDA error: an illegal memory access was encountered + terminate called after throwing an instance of 'c10::Error' + ``` + + Please kill all remaining python processes after this point, and use a smaller batch size + for train, validation and test sets so that CUDA memory is not exhausted. + + Args: + blank_idx: Number of target classes for the joint network to predict. + (Excluding the RNN-T blank token). + + reduction: Type of reduction to perform on loss. Possibly values are `mean`, `sum` or None. + None will return a torch vector comprising the individual loss values of the batch. + + loss_name: String that is resolved into an RNNT loss function. Available list of losses + is ininitialized in `RNNT_LOSS_RESOLVER` dictionary. + + loss_kwargs: Optional Dict of (str, value) pairs that are passed to the instantiated loss + function. + """ + super(RNNTLoss, self).__init__() + + if reduction not in [None, 'mean', 'sum', 'mean_batch']: + raise ValueError('`reduction` must be one of [mean, sum, mean_batch]') + + self._blank = blank_idx + self.reduction = reduction + self._loss = resolve_rnnt_loss(loss_name, blank_idx=self._blank, loss_kwargs=loss_kwargs) + + @typecheck() + def forward(self, log_probs, targets, input_lengths, target_lengths): + # Cast to int 32 + targets = targets.int() + input_lengths = input_lengths.int() + target_lengths = target_lengths.int() + + # max_logit_len = input_lengths.max() + max_targets_len = target_lengths.max() + + # Force cast joint to float32 + # TODO: Remove once Numba supports FP16 + if log_probs.dtype != torch.float32: + logits_orig = log_probs + log_probs = log_probs.float() + del logits_orig # save memory *before* computing the loss + + # Ensure that shape mismatch does not occur due to padding + # Due to padding and subsequent downsampling, it may be possible that + # max sequence length computed does not match the actual max sequence length + # of the log_probs tensor, therefore we increment the input_lengths by the difference. + # This difference is generally small. + # if log_probs.shape[1] != max_logit_len: + # log_probs = log_probs.narrow(dim=1, start=0, length=max_logit_len).contiguous() + + # Reduce transcript length to correct alignment if additional padding was applied. + # Transcript: [B, L] -> [B, L']; If L' < L + if targets.shape[1] != max_targets_len: + raise ValueError('targets len do not match targets') + targets = targets.narrow(dim=1, start=0, length=max_targets_len) + + # Loss reduction can be dynamic, so set it prior to call + if self.reduction != 'mean_batch': + self._loss.reduction = self.reduction + + # Compute RNNT loss + loss = self._loss(acts=log_probs, labels=targets, act_lens=input_lengths, label_lens=target_lengths) + + # Loss reduction can be dynamic, so reset it after call + if self.reduction != 'mean_batch': + self._loss.reduction = 'none' + + # Loss reduction only for mean_batch mode + if self.reduction == 'mean_batch': + loss = torch.mean(loss) + + # del new variables that may have been created + del ( + log_probs, + targets, + input_lengths, + target_lengths, + ) + + return loss diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/losses/similarityloss.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/losses/similarityloss.py new file mode 100644 index 0000000000000000000000000000000000000000..195cc26b19b9abeb0e20181a5c28fee3344bd3ef --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/losses/similarityloss.py @@ -0,0 +1,16 @@ +import torch + +from nemo.core import Loss + + +class NegativeCosineSimilarityLoss(Loss): + + def __init__(self, reduction: str = 'mean'): + super().__init__() + assert reduction == 'mean' + self.reduction = reduction + + def forward(self, predictions: torch.tensor, targets: torch.tensor): + similarity_scores = torch.cosine_similarity(predictions.float(), targets.float(), dim=-1).type_as(predictions) + loss = 1.0 - similarity_scores.mean() + return loss diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/losses/wav2vecloss.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/losses/wav2vecloss.py new file mode 100644 index 0000000000000000000000000000000000000000..bff55e5554743387f6472cf26599ad315232c80e --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/losses/wav2vecloss.py @@ -0,0 +1,119 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +import torch +import torch.nn.functional as F + +from nemo.core import Loss, typecheck +from nemo.core.neural_types import EncodedRepresentation, LossType, NeuralType + + +class Wav2VecLoss(Loss): + + def __init__(self, feature_loss_weight: float, prob_ppl_weight: float, logit_temp: float, reduction: str = 'mean'): + """ + Compute the contrastive loss with respect to the model outputs and sampled negatives from quantizer codebooks. + Args: + feature_loss_weight: Feature penalty weight (L2 Norm) + prob_ppl_weight: Perplexity Loss with respect to probabilities during quantization + logit_temp: Temperature normalization applied in loss. + reduction: Reduce loss via sum reduction (Default true) + """ + super().__init__() + self.feature_loss_weight = feature_loss_weight + self.prob_ppl_weight = prob_ppl_weight + self.logit_temp = logit_temp + assert reduction in ['mean', 'sum'] + self.reduction = reduction + + def forward( + self, + logits: torch.tensor, + targets: torch.tensor, + negatives: torch.tensor, + prob_ppl_loss: torch.tensor, + feature_loss: torch.tensor, + compute_accuracy: bool + ) -> [torch.tensor, torch.tensor, torch.tensor]: + """ + Args: + logits: Model activations + targets: The true target quantized representations + negatives: Sampled negatives from the quantizer codebooks. Sampled from all other timesteps. + feature_loss: Feature penalty (L2 Norm) + prob_ppl_loss: Perplexity Loss with respect to probs in quantization + Returns: + output loss values, feature loss, prob_ppl loss (after scaling). + """ + + # Calculate similarity between logits and all targets, returning FxBxT + similarity_scores = self._calculate_similarity(logits, negatives, targets) + + # Create targets of size B*T + similarity_targets = logits.new_zeros(similarity_scores.size(1) * similarity_scores.size(2), dtype=torch.long) + + # Transpose similarity scores to (T*B)xF for loss + similarity_scores = similarity_scores.transpose(0, 2) + similarity_scores = similarity_scores.reshape(-1, similarity_scores.size(-1)) + + contrastive_loss = F.cross_entropy(similarity_scores, similarity_targets, reduction=self.reduction) + loss = contrastive_loss + + sample_size = similarity_targets.numel() + + if self.prob_ppl_weight != 0: + prob_ppl_loss = self.prob_ppl_weight * prob_ppl_loss + if self.reduction == 'sum': + prob_ppl_loss = prob_ppl_loss * sample_size + loss = loss + prob_ppl_loss + + if self.feature_loss_weight != 0: + feature_loss = self.feature_loss_weight * feature_loss + if self.reduction == 'sum': + feature_loss = feature_loss * sample_size + loss = loss + feature_loss + + accuracy = None + if compute_accuracy: + with torch.no_grad(): + if similarity_scores.numel() == 0: + corr = 0 + count = 0 + accuracy = float('nan') + else: + assert similarity_scores.dim() > 1, similarity_scores.shape + max = similarity_scores.argmax(-1) == 0 + min = similarity_scores.argmin(-1) == 0 + both = max & min + corr = max.long().sum().item() - both.long().sum().item() + count = float(max.numel()) + accuracy = corr / count + + return loss, contrastive_loss, feature_loss, prob_ppl_loss, accuracy + + def _calculate_similarity(self, logits, negatives, targets): + neg_is_pos = (targets == negatives).all(-1) + targets = targets.unsqueeze(0) + targets = torch.cat([targets, negatives], dim=0) + logits = torch.cosine_similarity(logits.float(), targets.float(), dim=-1).type_as(logits) + logits /= self.logit_temp + if neg_is_pos.any(): + logits[1:][neg_is_pos] = float("-inf") + return logits diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/metrics/__init__.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/metrics/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9e3250071955216f6abc505e6181fb59931baa8d --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/metrics/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/metrics/rnnt_wer.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/metrics/rnnt_wer.py new file mode 100644 index 0000000000000000000000000000000000000000..bb5c4c7b07cc79ec3d00084f2789f4da1a0d74a3 --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/metrics/rnnt_wer.py @@ -0,0 +1,418 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC, abstractmethod +from typing import List, Optional + +import editdistance +import torch +from pytorch_lightning.metrics import Metric + +from nemo.collections.asr.parts import rnnt_beam_decoding as beam_decode +from nemo.collections.asr.parts import rnnt_greedy_decoding as greedy_decode +from nemo.collections.asr.parts import transformert_greedy_decoding, transformert_beam_decoding +from nemo.collections.asr.parts.rnnt_utils import Hypothesis, NBestHypotheses +from nemo.utils import logging + +__all__ = ['RNNTDecoding', 'RNNTWER'] + + +class AbstractRNNTDecoding(ABC): + """ + Used for performing RNN-T auto-regressive decoding of the Decoder+Joint network given the encoder state. + + Args: + decoding_cfg: A dict-like object which contains the following key-value pairs. + strategy: str value which represents the type of decoding that can occur. + Possible values are : + - greedy, greedy_batch (for greedy decoding). + - beam, tsd, alsd (for beam search decoding). + + The config may further contain the following sub-dictionaries: + "greedy": + max_symbols: int, describing the maximum number of target tokens to decode per + timestep during greedy decoding. Setting to larger values allows longer sentences + to be decoded, at the cost of increased execution time. + + "beam": + beam_size: int, defining the beam size for beam search. Must be >= 1. + If beam_size == 1, will perform cached greedy search. This might be slightly different + results compared to the greedy search above. + + score_norm: optional bool, whether to normalize the returned beam score in the hypotheses. + Set to True by default. + + return_best_hypothesis: optional bool, whether to return just the best hypothesis or all of the + hypotheses after beam search has concluded. This flag is set by default. + + tsd_max_sym_exp: optional int, determines number of symmetric expansions of the target symbols + per timestep of the acoustic model. Larger values will allow longer sentences to be decoded, + at increased cost to execution time. + + alsd_max_target_len: optional int or float, determines the potential maximum target sequence length. + If an integer is provided, it can decode sequences of that particular maximum length. + If a float is provided, it can decode sequences of int(alsd_max_target_len * seq_len), + where seq_len is the length of the acoustic model output (T). + + NOTE: + If a float is provided, it can be greater than 1! + By default, a float of 2.0 is used so that a target sequence can be at most twice + as long as the acoustic model output length T. + + decoder: The Decoder/Prediction network module. + joint: The Joint network module. + blank_id: The id of the RNNT blank token. + """ + + def __init__(self, decoding_cfg, decoder, joint, blank_id): + super(AbstractRNNTDecoding, self).__init__() + self.cfg = decoding_cfg + self.blank_id = blank_id + + possible_strategies = ['greedy', 'greedy_batch', 'beam', 'tsd', 'alsd', 'transformer_greedy', + 'transformer_beam'] + if self.cfg.strategy not in possible_strategies: + raise ValueError(f"Decoding strategy must be one of {possible_strategies}") + + if self.cfg.strategy == 'greedy': + self.decoding = greedy_decode.GreedyRNNTInfer( + decoder_model=decoder, + joint_model=joint, + blank_index=self.blank_id, + max_symbols_per_step=self.cfg.greedy.get('max_symbols', None), + ) + + elif self.cfg.strategy == 'greedy_batch': + self.decoding = greedy_decode.GreedyBatchedRNNTInfer( + decoder_model=decoder, + joint_model=joint, + blank_index=self.blank_id, + max_symbols_per_step=self.cfg.greedy.get('max_symbols', None), + ) + + elif self.cfg.strategy == 'beam': + + self.decoding = beam_decode.BeamRNNTInfer( + decoder_model=decoder, + joint_model=joint, + beam_size=self.cfg.beam.beam_size, + return_best_hypothesis=decoding_cfg.beam.get('return_best_hypothesis', True), + search_type='default', + score_norm=self.cfg.beam.get('score_norm', True), + ) + + elif self.cfg.strategy == 'tsd': + + self.decoding = beam_decode.BeamRNNTInfer( + decoder_model=decoder, + joint_model=joint, + beam_size=self.cfg.beam.beam_size, + return_best_hypothesis=decoding_cfg.beam.get('return_best_hypothesis', True), + search_type='tsd', + score_norm=self.cfg.beam.get('score_norm', True), + tsd_max_sym_exp_per_step=self.cfg.beam.get('tsd_max_sym_exp', 50), + ) + + elif self.cfg.strategy == 'alsd': + + self.decoding = beam_decode.BeamRNNTInfer( + decoder_model=decoder, + joint_model=joint, + beam_size=self.cfg.beam.beam_size, + return_best_hypothesis=decoding_cfg.beam.get('return_best_hypothesis', True), + search_type='alsd', + score_norm=self.cfg.beam.get('score_norm', True), + alsd_max_target_len=self.cfg.beam.get('alsd_max_target_len', 2), + ) + + elif self.cfg.strategy == 'transformer_greedy': + + self.decoding = transformert_greedy_decoding.GreedyTransformerTInfer( + decoder_model=decoder, + joint_model=joint, + blank_index=self.blank_id, + max_symbols_per_step=self.cfg.greedy.get('max_symbols', None), + ) + + elif self.cfg.strategy == 'transformer_beam': + + self.decoding = transformert_beam_decoding.BeamTransformerTInfer( + decoder_model=decoder, + joint_model=joint, + blank_index=self.blank_id, + **self.cfg.beam + ) + + def rnnt_decoder_predictions_tensor( + self, encoder_output: torch.Tensor, encoded_lengths: torch.Tensor + ) -> (List[str], Optional[List[List[str]]]): + """ + Decode an encoder output by autoregressive decoding of the Decoder+Joint networks. + + Args: + encoder_output: torch.Tensor of shape [B, D, T]. + encoded_lengths: torch.Tensor containing lengths of the padded encoder outputs. Shape [B]. + + Returns: + If `return_best_hypothesis` is set: + A tuple (hypotheses, None): + hypotheses - list of Hypothesis (best hypothesis per sample). + Look at rnnt_utils.Hypothesis for more information. + + If `return_best_hypothesis` is not set: + A tuple(hypotheses, all_hypotheses) + hypotheses - list of Hypothesis (best hypothesis per sample). + Look at rnnt_utils.Hypothesis for more information. + all_hypotheses - list of NBestHypotheses. Each NBestHypotheses further contains a sorted + list of all the hypotheses of the model per sample. + Look at rnnt_utils.NBestHypotheses for more information. + """ + # Compute hypotheses + with torch.no_grad(): + hypotheses_list = self.decoding( + encoder_output=encoder_output, encoded_lengths=encoded_lengths + ) # type: [List[Hypothesis]] + + # extract the hypotheses + hypotheses_list = hypotheses_list[0] # type: List[Hypothesis] + + prediction_list = hypotheses_list + + if isinstance(prediction_list[0], NBestHypotheses): + hypotheses = [] + all_hypotheses = [] + for nbest_hyp in prediction_list: # type: NBestHypotheses + n_hyps = nbest_hyp.n_best_hypotheses # Extract all hypotheses for this sample + decoded_hyps = self.decode_hypothesis(n_hyps) # type: List[str] + hypotheses.append(decoded_hyps[0]) # best hypothesis + all_hypotheses.append(decoded_hyps) + + return hypotheses, all_hypotheses + else: + hypotheses = self.decode_hypothesis(prediction_list) # type: List[str] + return hypotheses, None + + def decode_hypothesis(self, hypotheses_list: List[Hypothesis]) -> List[str]: + """ + Decode a list of hypotheses into a list of strings. + + Args: + hypotheses_list: List of Hypothesis. + + Returns: + A list of strings. + """ + hypotheses = [] + for ind in range(len(hypotheses_list)): + # Extract the integer encoded hypothesis + prediction = hypotheses_list[ind].y_sequence + + if type(prediction) != list: + prediction = prediction.tolist() + + # RNN-T sample level is already preprocessed by implicit CTC decoding + # Simply remove any blank tokens + prediction = [p for p in prediction if p != self.blank_id] + + # De-tokenize the integer tokens + hypothesis = self.decode_tokens_to_str(prediction) + hypotheses.append(hypothesis) + + return hypotheses + + @abstractmethod + def decode_tokens_to_str(self, tokens: List[int]) -> str: + """ + Implemented by subclass in order to decoder a token list into a string. + + Args: + tokens: List of int representing the token ids. + + Returns: + A decoded string. + """ + raise NotImplementedError() + + +class RNNTDecoding(AbstractRNNTDecoding): + """ + Used for performing RNN-T auto-regressive decoding of the Decoder+Joint network given the encoder state. + + Args: + decoding_cfg: A dict-like object which contains the following key-value pairs. + strategy: str value which represents the type of decoding that can occur. + Possible values are : + - greedy, greedy_batch (for greedy decoding). + - beam, tsd, alsd (for beam search decoding). + + The config may further contain the following sub-dictionaries: + "greedy": + max_symbols: int, describing the maximum number of target tokens to decode per + timestep during greedy decoding. Setting to larger values allows longer sentences + to be decoded, at the cost of increased execution time. + + "beam": + beam_size: int, defining the beam size for beam search. Must be >= 1. + If beam_size == 1, will perform cached greedy search. This might be slightly different + results compared to the greedy search above. + + score_norm: optional bool, whether to normalize the returned beam score in the hypotheses. + Set to True by default. + + return_best_hypothesis: optional bool, whether to return just the best hypothesis or all of the + hypotheses after beam search has concluded. This flag is set by default. + + tsd_max_sym_exp: optional int, determines number of symmetric expansions of the target symbols + per timestep of the acoustic model. Larger values will allow longer sentences to be decoded, + at increased cost to execution time. + + alsd_max_target_len: optional int or float, determines the potential maximum target sequence length. + If an integer is provided, it can decode sequences of that particular maximum length. + If a float is provided, it can decode sequences of int(alsd_max_target_len * seq_len), + where seq_len is the length of the acoustic model output (T). + + NOTE: + If a float is provided, it can be greater than 1! + By default, a float of 2.0 is used so that a target sequence can be at most twice + as long as the acoustic model output length T. + + decoder: The Decoder/Prediction network module. + joint: The Joint network module. + vocabulary: The vocabulary (excluding the RNNT blank token) which will be used for decoding. + """ + + def __init__( + self, decoding_cfg, decoder, joint, vocabulary, + ): + blank_id = decoder.blank_idx + self.labels_map = dict([(i, vocabulary[i]) for i in range(len(vocabulary))]) + + super(RNNTDecoding, self).__init__(decoding_cfg=decoding_cfg, decoder=decoder, joint=joint, blank_id=blank_id) + + def decode_tokens_to_str(self, tokens: List[int]) -> str: + """ + Implemented by subclass in order to decoder a token list into a string. + + Args: + tokens: List of int representing the token ids. + + Returns: + A decoded string. + """ + hypothesis = ''.join([self.labels_map[c] for c in tokens if c != self.blank_id]) + return hypothesis + + +class RNNTWER(Metric): + """ + This metric computes numerator and denominator for Overall Word Error Rate (WER) between prediction and reference texts. + When doing distributed training/evaluation the result of res=WER(predictions, targets, target_lengths) calls + will be all-reduced between all workers using SUM operations. + Here contains two numbers res=[wer_numerator, wer_denominator]. WER=wer_numerator/wer_denominator. + + If used with PytorchLightning LightningModule, include wer_numerator and wer_denominators inside validation_step results. + Then aggregate (sum) then at the end of validation epoch to correctly compute validation WER. + + Example: + def validation_step(self, batch, batch_idx): + ... + wer_num, wer_denom = self.__wer(predictions, transcript, transcript_len) + return {'val_loss': loss_value, 'val_wer_num': wer_num, 'val_wer_denom': wer_denom} + + def validation_epoch_end(self, outputs): + ... + wer_num = torch.stack([x['val_wer_num'] for x in outputs]).sum() + wer_denom = torch.stack([x['val_wer_denom'] for x in outputs]).sum() + tensorboard_logs = {'validation_loss': val_loss_mean, 'validation_avg_wer': wer_num / wer_denom} + return {'val_loss': val_loss_mean, 'log': tensorboard_logs} + + Args: + decoding: RNNTDecoding object that will perform autoregressive decoding of the RNNT model. + batch_dim_index: Index of the batch dimension. + use_cer: Whether to use Character Error Rate isntead of Word Error Rate. + log_prediction: Whether to log a single decoded sample per call. + + Returns: + res: a torch.Tensor object with two elements: [wer_numerator, wer_denominator]. To correctly compute average + text word error rate, compute wer=wer_numerator/wer_denominator + """ + + def __init__( + self, decoding: RNNTDecoding, batch_dim_index=0, use_cer=False, log_prediction=True, dist_sync_on_step=False + ): + super(RNNTWER, self).__init__(dist_sync_on_step=dist_sync_on_step, compute_on_step=False) + self.decoding = decoding + self.batch_dim_index = batch_dim_index + self.use_cer = use_cer + self.log_prediction = log_prediction + self.blank_id = self.decoding.blank_id + self.labels_map = self.decoding.labels_map + + self.add_state("scores", default=torch.tensor(0), dist_reduce_fx='sum', persistent=False) + self.add_state("words", default=torch.tensor(0), dist_reduce_fx='sum', persistent=False) + + def update( + self, + encoder_output: torch.Tensor, + encoded_lengths: torch.Tensor, + targets: torch.Tensor, + target_lengths: torch.Tensor, + decode_results: dict = None, + log_prediction: bool = True + ) -> torch.Tensor: + words = 0.0 + scores = 0.0 + references = [] + with torch.no_grad(): + # prediction_cpu_tensor = tensors[0].long().cpu() + targets_cpu_tensor = targets.long().cpu() + tgt_lenths_cpu_tensor = target_lengths.long().cpu() + + # iterate over batch + for ind in range(targets_cpu_tensor.shape[self.batch_dim_index]): + tgt_len = tgt_lenths_cpu_tensor[ind].item() + target = targets_cpu_tensor[ind][:tgt_len].numpy().tolist() + + reference = self.decoding.decode_tokens_to_str(target) + references.append(reference) + + hypotheses, _ = self.decoding.rnnt_decoder_predictions_tensor(encoder_output, encoded_lengths) + + if decode_results is not None: + decode_results['references'] = references + decode_results['hypotheses'] = hypotheses + + if self.log_prediction and log_prediction: + logging.info(f"\n") + logging.info(f"reference :{references[0]}") + logging.info(f"predicted :{hypotheses[0]}") + + for h, r in zip(hypotheses, references): + if self.use_cer: + h_list = list(h) + r_list = list(r) + else: + h_list = h.split() + r_list = r.split() + words += len(r_list) + # Compute Levenshtein's distance + scores += editdistance.eval(h_list, r_list) + + self.scores += torch.tensor(scores, device=self.scores.device, dtype=self.scores.dtype) + self.words += torch.tensor(words, device=self.words.device, dtype=self.words.dtype) + # return torch.tensor([scores, words]).to(predictions.device) + + def compute(self): + wer = self.scores.float() / self.words + return wer, self.scores.detach(), self.words.detach() diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/metrics/rnnt_wer_bpe.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/metrics/rnnt_wer_bpe.py new file mode 100644 index 0000000000000000000000000000000000000000..817db947c6a303ef7463058ed1b4c84a6a32570f --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/metrics/rnnt_wer_bpe.py @@ -0,0 +1,207 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional + +import editdistance +import torch +from pytorch_lightning.metrics import Metric + +from nemo.collections.asr.metrics.rnnt_wer import AbstractRNNTDecoding +from nemo.collections.asr.parts import rnnt_beam_decoding as beam_decode +from nemo.collections.asr.parts import rnnt_greedy_decoding as greedy_decode +from nemo.collections.asr.parts.rnnt_utils import Hypothesis, NBestHypotheses +from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec +from nemo.utils import logging + +__all__ = ['RNNTBPEDecoding', 'RNNTBPEWER'] + + +class RNNTBPEDecoding(AbstractRNNTDecoding): + """ + Used for performing RNN-T auto-regressive decoding of the Decoder+Joint network given the encoder state. + + Args: + decoding_cfg: A dict-like object which contains the following key-value pairs. + strategy: str value which represents the type of decoding that can occur. + Possible values are : + - greedy, greedy_batch (for greedy decoding). + - beam, tsd, alsd (for beam search decoding). + + The config may further contain the following sub-dictionaries: + "greedy": + max_symbols: int, describing the maximum number of target tokens to decode per + timestep during greedy decoding. Setting to larger values allows longer sentences + to be decoded, at the cost of increased execution time. + + "beam": + beam_size: int, defining the beam size for beam search. Must be >= 1. + If beam_size == 1, will perform cached greedy search. This might be slightly different + results compared to the greedy search above. + + score_norm: optional bool, whether to normalize the returned beam score in the hypotheses. + Set to True by default. + + return_best_hypothesis: optional bool, whether to return just the best hypothesis or all of the + hypotheses after beam search has concluded. + + tsd_max_sym_exp: optional int, determines number of symmetric expansions of the target symbols + per timestep of the acoustic model. Larger values will allow longer sentences to be decoded, + at increased cost to execution time. + + alsd_max_target_len: optional int or float, determines the potential maximum target sequence length. + If an integer is provided, it can decode sequences of that particular maximum length. + If a float is provided, it can decode sequences of int(alsd_max_target_len * seq_len), + where seq_len is the length of the acoustic model output (T). + + NOTE: + If a float is provided, it can be greater than 1! + By default, a float of 2.0 is used so that a target sequence can be at most twice + as long as the acoustic model output length T. + + decoder: The Decoder/Prediction network module. + joint: The Joint network module. + tokenizer: The tokenizer which will be used for decoding. + """ + + def __init__(self, decoding_cfg, decoder, joint, tokenizer: TokenizerSpec): + blank_id = decoder.blank_idx + self.tokenizer = tokenizer + + super(RNNTBPEDecoding, self).__init__( + decoding_cfg=decoding_cfg, decoder=decoder, joint=joint, blank_id=blank_id + ) + + def decode_tokens_to_str(self, tokens: List[int]) -> str: + """ + Implemented by subclass in order to decoder a token list into a string. + + Args: + tokens: List of int representing the token ids. + + Returns: + A decoded string. + """ + hypothesis = self.tokenizer.ids_to_text(tokens) + return hypothesis + + +class RNNTBPEWER(Metric): + """ + This metric computes numerator and denominator for Overall Word Error Rate (WER) between prediction and reference texts. + When doing distributed training/evaluation the result of res=WER(predictions, targets, target_lengths) calls + will be all-reduced between all workers using SUM operations. + Here contains two numbers res=[wer_numerator, wer_denominator]. WER=wer_numerator/wer_denominator. + + If used with PytorchLightning LightningModule, include wer_numerator and wer_denominators inside validation_step results. + Then aggregate (sum) then at the end of validation epoch to correctly compute validation WER. + + Example: + def validation_step(self, batch, batch_idx): + ... + wer_num, wer_denom = self.__wer(predictions, transcript, transcript_len) + return {'val_loss': loss_value, 'val_wer_num': wer_num, 'val_wer_denom': wer_denom} + + def validation_epoch_end(self, outputs): + ... + wer_num = torch.stack([x['val_wer_num'] for x in outputs]).sum() + wer_denom = torch.stack([x['val_wer_denom'] for x in outputs]).sum() + tensorboard_logs = {'validation_loss': val_loss_mean, 'validation_avg_wer': wer_num / wer_denom} + return {'val_loss': val_loss_mean, 'log': tensorboard_logs} + + Args: + decoding: RNNTBPEDecoding object that will perform autoregressive decoding of the RNNT model. + batch_dim_index: Index of the batch dimension. + use_cer: Whether to use Character Error Rate isntead of Word Error Rate. + log_prediction: Whether to log a single decoded sample per call. + + Returns: + res: a torch.Tensor object with two elements: [wer_numerator, wer_denominator]. To correctly compute average + text word error rate, compute wer=wer_numerator/wer_denominator + """ + + def __init__( + self, + decoding: RNNTBPEDecoding, + batch_dim_index=0, + use_cer: bool = False, + log_prediction: bool = True, + dist_sync_on_step=False, + ): + super(RNNTBPEWER, self).__init__(dist_sync_on_step=dist_sync_on_step, compute_on_step=False) + self.decoding = decoding + self.batch_dim_index = batch_dim_index + self.use_cer = use_cer + self.log_prediction = log_prediction + self.blank_id = self.decoding.blank_id + self.tokenizer = self.decoding.tokenizer + + self.add_state("scores", default=torch.tensor(0), dist_reduce_fx='sum', persistent=False) + self.add_state("words", default=torch.tensor(0), dist_reduce_fx='sum', persistent=False) + + def update( + self, + encoder_output: torch.Tensor, + encoded_lengths: torch.Tensor, + targets: torch.Tensor, + target_lengths: torch.Tensor, + decode_results: dict = None, + log_prediction: bool = True + ) -> torch.Tensor: + words = 0.0 + scores = 0.0 + references = [] + with torch.no_grad(): + # prediction_cpu_tensor = tensors[0].long().cpu() + targets_cpu_tensor = targets.long().cpu() + tgt_lenths_cpu_tensor = target_lengths.long().cpu() + + # iterate over batch + for ind in range(targets_cpu_tensor.shape[self.batch_dim_index]): + tgt_len = tgt_lenths_cpu_tensor[ind].item() + target = targets_cpu_tensor[ind][:tgt_len].numpy().tolist() + reference = self.decoding.decode_tokens_to_str(target) + references.append(reference) + + hypotheses, _ = self.decoding.rnnt_decoder_predictions_tensor(encoder_output, encoded_lengths) + + if decode_results is not None: + decode_results['references'] = references + decode_results['hypotheses'] = hypotheses + + if self.log_prediction and log_prediction: + logging.info(f"\n") + logging.info(f"reference :{references[0]}") + logging.info(f"predicted :{hypotheses[0]}") + + for h, r in zip(hypotheses, references): + if self.use_cer: + h_list = list(h) + r_list = list(r) + else: + h_list = h.split() + r_list = r.split() + words += len(r_list) + # Compute Levenshtein's distance + scores += editdistance.eval(h_list, r_list) + + del hypotheses + + self.scores += torch.tensor(scores, device=self.scores.device, dtype=self.scores.dtype) + self.words += torch.tensor(words, device=self.words.device, dtype=self.words.dtype) + # return torch.tensor([scores, words]).to(predictions.device) + + def compute(self): + wer = self.scores.float() / self.words + return wer, self.scores.detach(), self.words.detach() diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/metrics/wer.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/metrics/wer.py new file mode 100644 index 0000000000000000000000000000000000000000..8cc9ac1c0db830e7e03a40c550da0017c406a6b8 --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/metrics/wer.py @@ -0,0 +1,286 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List + +import editdistance +import torch +from pytorch_lightning.metrics import Metric + +from nemo.collections.asr.parts.rnnt_utils import Hypothesis +from nemo.utils import logging + +__all__ = ['word_error_rate', 'WER'] + + +def word_error_rate(hypotheses: List[str], references: List[str], use_cer=False) -> float: + """ + Computes Average Word Error rate between two texts represented as + corresponding lists of string. Hypotheses and references must have same + length. + Args: + hypotheses: list of hypotheses + references: list of references + use_cer: bool, set True to enable cer + Returns: + (float) average word error rate + """ + scores = 0 + words = 0 + if len(hypotheses) != len(references): + raise ValueError( + "In word error rate calculation, hypotheses and reference" + " lists must have the same number of elements. But I got:" + "{0} and {1} correspondingly".format(len(hypotheses), len(references)) + ) + for h, r in zip(hypotheses, references): + if use_cer: + h_list = list(h) + r_list = list(r) + else: + h_list = h.split() + r_list = r.split() + words += len(r_list) + scores += editdistance.eval(h_list, r_list) + if words != 0: + wer = 1.0 * scores / words + else: + wer = float('inf') + return wer + + +class WER(Metric): + """ + This metric computes numerator and denominator for Overall Word Error Rate (WER) between prediction and reference texts. + When doing distributed training/evaluation the result of res=WER(predictions, targets, target_lengths) calls + will be all-reduced between all workers using SUM operations. + Here contains two numbers res=[wer_numerator, wer_denominator]. WER=wer_numerator/wer_denominator. + + If used with PytorchLightning LightningModule, include wer_numerator and wer_denominators inside validation_step results. + Then aggregate (sum) then at the end of validation epoch to correctly compute validation WER. + + Example: + def validation_step(self, batch, batch_idx): + ... + wer_num, wer_denom = self.__wer(predictions, transcript, transcript_len) + return {'val_loss': loss_value, 'val_wer_num': wer_num, 'val_wer_denom': wer_denom} + + def validation_epoch_end(self, outputs): + ... + wer_num = torch.stack([x['val_wer_num'] for x in outputs]).sum() + wer_denom = torch.stack([x['val_wer_denom'] for x in outputs]).sum() + tensorboard_logs = {'validation_loss': val_loss_mean, 'validation_avg_wer': wer_num / wer_denom} + return {'val_loss': val_loss_mean, 'log': tensorboard_logs} + + Args: + vocabulary: List of strings that describes the vocabulary of the dataset. + batch_dim_index: Index of the batch dimension. + use_cer: Whether to use Character Error Rate instead of Word Error Rate. + ctc_decode: Whether to use CTC decoding or not. Currently, must be set. + log_prediction: Whether to log a single decoded sample per call. + + Returns: + res: a torch.Tensor object with two elements: [wer_numerator, wer_denominator]. To correctly compute average + text word error rate, compute wer=wer_numerator/wer_denominator + """ + + def __init__( + self, + vocabulary, + *, + blank_id, + batch_dim_index=0, + use_cer=False, + ctc_decode=True, + log_prediction=True, + dist_sync_on_step=False, + strip_end_space=False + ): + super().__init__(dist_sync_on_step=dist_sync_on_step, compute_on_step=False) + self.batch_dim_index = batch_dim_index + self.blank_id = blank_id + self.labels_map = dict([(i, vocabulary[i]) for i in range(len(vocabulary))]) + self.use_cer = use_cer + self.ctc_decode = ctc_decode + self.log_prediction = log_prediction + self.strip_end_space = strip_end_space + if self.strip_end_space: + print('INFO: WER strip_end_space') + + self.add_state("scores", default=torch.tensor(0), dist_reduce_fx='sum', persistent=False) + self.add_state("words", default=torch.tensor(0), dist_reduce_fx='sum', persistent=False) + + def ctc_decoder_predictions_tensor( + self, predictions: torch.Tensor, predictions_len: torch.Tensor = None, return_hypotheses: bool = False, return_ctc_tokens: bool = False, + ) -> List[str]: + """ + Decodes a sequence of labels to words + + Args: + predictions: A torch.Tensor of shape [Batch, Time] of integer indices that correspond + to the index of some character in the label set. + predictions_len: Optional tensor of length `Batch` which contains the integer lengths + of the sequence in the padded `predictions` tensor. + return_hypotheses: Bool flag whether to return just the decoding predictions of the model + or a Hypothesis object that holds information such as the decoded `text`, + the `alignment` of emited by the CTC Model, and the `length` of the sequence (if available). + May also contain the log-probabilities of the decoder (if this method is called via + transcribe()) + + Returns: + Either a list of str which represent the CTC decoded strings per sample, + or a list of Hypothesis objects containing additional information. + """ + hypotheses = [] + # Drop predictions to CPU + prediction_cpu_tensor = predictions.long().cpu() + # iterate over batch + for ind in range(prediction_cpu_tensor.shape[self.batch_dim_index]): + prediction = prediction_cpu_tensor[ind].detach().numpy().tolist() + if predictions_len is not None: + prediction = prediction[: predictions_len[ind]] + # CTC decoding procedure + decoded_prediction = [] + previous = self.blank_id + for p in prediction: + if (p != previous or previous == self.blank_id) and p != self.blank_id: + decoded_prediction.append(p) + previous = p + + text = self.decode_tokens_to_str(decoded_prediction) + + if not return_hypotheses: + # hypothesis = text + + # By Dehua + if not return_ctc_tokens: + hypothesis = text + else: + ctc_tokens = [self.labels_map[c] if c != self.blank_id else "BLANK" for c in prediction] + hypothesis = ' '.join(ctc_tokens) + else: + hypothesis = Hypothesis( + y_sequence=None, + score=-1.0, + text=text, + alignments=prediction, + length=predictions_len[ind] if predictions_len is not None else 0, + ) + + hypotheses.append(hypothesis) + return hypotheses + + def decode_tokens_to_str(self, tokens: List[int]) -> str: + """ + Implemented in order to decoder a token list into a string. + + Args: + tokens: List of int representing the token ids. + + Returns: + A decoded string. + """ + hypothesis = ''.join(self.decode_ids_to_tokens(tokens)) + return hypothesis + + def decode_ids_to_tokens(self, tokens: List[int]) -> List[str]: + """ + Implemented in order to decode a token id list into a token list. + A token list is the string representation of each token id. + + Args: + tokens: List of int representing the token ids. + + Returns: + A list of decoded tokens. + """ + token_list = [self.labels_map[c] for c in tokens if c != self.blank_id] + return token_list + + def update( + self, + predictions: torch.Tensor, + targets: torch.Tensor, + target_lengths: torch.Tensor, + predictions_lengths: torch.Tensor = None, + log_prediction=False, + decode_results: dict = None, + ) -> torch.Tensor: + words = 0.0 + scores = 0.0 + references = [] + with torch.no_grad(): + # prediction_cpu_tensor = tensors[0].long().cpu() + targets_cpu_tensor = targets.long().cpu() + tgt_lenths_cpu_tensor = target_lengths.long().cpu() + + # iterate over batch + for ind in range(targets_cpu_tensor.shape[self.batch_dim_index]): + tgt_len = tgt_lenths_cpu_tensor[ind].item() + target = targets_cpu_tensor[ind][:tgt_len].numpy().tolist() + reference = self.decode_tokens_to_str(target) + references.append(reference) + if self.ctc_decode: + hypotheses = self.ctc_decoder_predictions_tensor(predictions, predictions_lengths) + else: + raise NotImplementedError("Implement me if you need non-CTC decode on predictions") + + if self.log_prediction or log_prediction: + logging.info(f"\n") + logging.info(f"reference:{references[0]}") + logging.info(f"predicted:{hypotheses[0]}") + + if self.strip_end_space: + # in my_preprocess, there may be space added at the end + references = [ref_i.rstrip(' ') for ref_i in references] + hypotheses = [hyp_i.rstrip(' ') for hyp_i in hypotheses] + + if decode_results is not None: + decode_results['references'] = references + decode_results['hypotheses'] = hypotheses + + for h, r in zip(hypotheses, references): + if self.use_cer: + h_list = list(h) + r_list = list(r) + else: + h_list = h.split() + r_list = r.split() + words += len(r_list) + # Compute Levenstein's distance + scores += editdistance.eval(h_list, r_list) + + self.scores = torch.tensor(scores, device=self.scores.device, dtype=self.scores.dtype) + self.words = torch.tensor(words, device=self.words.device, dtype=self.words.dtype) + # return torch.tensor([scores, words]).to(predictions.device) + + def compute(self): + scores = self.scores.detach().float() + words = self.words.detach().float() + return scores / words, scores, words + +class WER_phone(WER): + + def decode_tokens_to_str(self, tokens: List[int]) -> str: + """ + Implemented in order to decoder a token list into a string. + + Args: + tokens: List of int representing the token ids. + + Returns: + A decoded string. + """ + hypothesis = ' '.join(self.decode_ids_to_tokens(tokens)) + return hypothesis \ No newline at end of file diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/metrics/wer_bak_orig_20240130.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/metrics/wer_bak_orig_20240130.py new file mode 100644 index 0000000000000000000000000000000000000000..1d58c9d7315f39c0a88ef36ee91cfba15d5b06a8 --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/metrics/wer_bak_orig_20240130.py @@ -0,0 +1,283 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List + +import editdistance +import torch +from pytorch_lightning.metrics import Metric + +from nemo.collections.asr.parts.rnnt_utils import Hypothesis +from nemo.utils import logging + +__all__ = ['word_error_rate', 'WER'] + + +def word_error_rate(hypotheses: List[str], references: List[str], use_cer=False) -> float: + """ + Computes Average Word Error rate between two texts represented as + corresponding lists of string. Hypotheses and references must have same + length. + Args: + hypotheses: list of hypotheses + references: list of references + use_cer: bool, set True to enable cer + Returns: + (float) average word error rate + """ + scores = 0 + words = 0 + if len(hypotheses) != len(references): + raise ValueError( + "In word error rate calculation, hypotheses and reference" + " lists must have the same number of elements. But I got:" + "{0} and {1} correspondingly".format(len(hypotheses), len(references)) + ) + for h, r in zip(hypotheses, references): + if use_cer: + h_list = list(h) + r_list = list(r) + else: + h_list = h.split() + r_list = r.split() + words += len(r_list) + scores += editdistance.eval(h_list, r_list) + if words != 0: + wer = 1.0 * scores / words + else: + wer = float('inf') + return wer + + +class WER(Metric): + """ + This metric computes numerator and denominator for Overall Word Error Rate (WER) between prediction and reference texts. + When doing distributed training/evaluation the result of res=WER(predictions, targets, target_lengths) calls + will be all-reduced between all workers using SUM operations. + Here contains two numbers res=[wer_numerator, wer_denominator]. WER=wer_numerator/wer_denominator. + + If used with PytorchLightning LightningModule, include wer_numerator and wer_denominators inside validation_step results. + Then aggregate (sum) then at the end of validation epoch to correctly compute validation WER. + + Example: + def validation_step(self, batch, batch_idx): + ... + wer_num, wer_denom = self.__wer(predictions, transcript, transcript_len) + return {'val_loss': loss_value, 'val_wer_num': wer_num, 'val_wer_denom': wer_denom} + + def validation_epoch_end(self, outputs): + ... + wer_num = torch.stack([x['val_wer_num'] for x in outputs]).sum() + wer_denom = torch.stack([x['val_wer_denom'] for x in outputs]).sum() + tensorboard_logs = {'validation_loss': val_loss_mean, 'validation_avg_wer': wer_num / wer_denom} + return {'val_loss': val_loss_mean, 'log': tensorboard_logs} + + Args: + vocabulary: List of strings that describes the vocabulary of the dataset. + batch_dim_index: Index of the batch dimension. + use_cer: Whether to use Character Error Rate instead of Word Error Rate. + ctc_decode: Whether to use CTC decoding or not. Currently, must be set. + log_prediction: Whether to log a single decoded sample per call. + + Returns: + res: a torch.Tensor object with two elements: [wer_numerator, wer_denominator]. To correctly compute average + text word error rate, compute wer=wer_numerator/wer_denominator + """ + + def __init__( + self, + vocabulary, + *, + blank_id, + batch_dim_index=0, + use_cer=False, + ctc_decode=True, + log_prediction=True, + dist_sync_on_step=False, + strip_end_space=False + ): + super().__init__(dist_sync_on_step=dist_sync_on_step, compute_on_step=False) + self.batch_dim_index = batch_dim_index + self.blank_id = blank_id + self.labels_map = dict([(i, vocabulary[i]) for i in range(len(vocabulary))]) + self.use_cer = use_cer + self.ctc_decode = ctc_decode + self.log_prediction = log_prediction + self.strip_end_space = strip_end_space + if self.strip_end_space: + print('INFO: WER strip_end_space') + + self.add_state("scores", default=torch.tensor(0), dist_reduce_fx='sum', persistent=False) + self.add_state("words", default=torch.tensor(0), dist_reduce_fx='sum', persistent=False) + + def ctc_decoder_predictions_tensor( + self, predictions: torch.Tensor, predictions_len: torch.Tensor = None, return_hypotheses: bool = False, + ) -> List[str]: + """ + Decodes a sequence of labels to words + + Args: + predictions: A torch.Tensor of shape [Batch, Time] of integer indices that correspond + to the index of some character in the label set. + predictions_len: Optional tensor of length `Batch` which contains the integer lengths + of the sequence in the padded `predictions` tensor. + return_hypotheses: Bool flag whether to return just the decoding predictions of the model + or a Hypothesis object that holds information such as the decoded `text`, + the `alignment` of emited by the CTC Model, and the `length` of the sequence (if available). + May also contain the log-probabilities of the decoder (if this method is called via + transcribe()) + + Returns: + Either a list of str which represent the CTC decoded strings per sample, + or a list of Hypothesis objects containing additional information. + """ + hypotheses = [] + # Drop predictions to CPU + prediction_cpu_tensor = predictions.long().cpu() + # iterate over batch + for ind in range(prediction_cpu_tensor.shape[self.batch_dim_index]): + prediction = prediction_cpu_tensor[ind].detach().numpy().tolist() + if predictions_len is not None: + prediction = prediction[: predictions_len[ind]] + # CTC decoding procedure + decoded_prediction = [] + previous = self.blank_id + for p in prediction: + if (p != previous or previous == self.blank_id) and p != self.blank_id: + decoded_prediction.append(p) + previous = p + + text = self.decode_tokens_to_str(decoded_prediction) + + if not return_hypotheses: + hypothesis = text + + # # By Dehua + # ctc_tokens = [self.labels_map[c] if c != self.blank_id else "BLANK" for c in prediction] + # hypothesis = ' '.join(ctc_tokens) + else: + hypothesis = Hypothesis( + y_sequence=None, + score=-1.0, + text=text, + alignments=prediction, + length=predictions_len[ind] if predictions_len is not None else 0, + ) + + hypotheses.append(hypothesis) + return hypotheses + + def decode_tokens_to_str(self, tokens: List[int]) -> str: + """ + Implemented in order to decoder a token list into a string. + + Args: + tokens: List of int representing the token ids. + + Returns: + A decoded string. + """ + hypothesis = ''.join(self.decode_ids_to_tokens(tokens)) + return hypothesis + + def decode_ids_to_tokens(self, tokens: List[int]) -> List[str]: + """ + Implemented in order to decode a token id list into a token list. + A token list is the string representation of each token id. + + Args: + tokens: List of int representing the token ids. + + Returns: + A list of decoded tokens. + """ + token_list = [self.labels_map[c] for c in tokens if c != self.blank_id] + return token_list + + def update( + self, + predictions: torch.Tensor, + targets: torch.Tensor, + target_lengths: torch.Tensor, + predictions_lengths: torch.Tensor = None, + log_prediction=False, + decode_results: dict = None, + ) -> torch.Tensor: + words = 0.0 + scores = 0.0 + references = [] + with torch.no_grad(): + # prediction_cpu_tensor = tensors[0].long().cpu() + targets_cpu_tensor = targets.long().cpu() + tgt_lenths_cpu_tensor = target_lengths.long().cpu() + + # iterate over batch + for ind in range(targets_cpu_tensor.shape[self.batch_dim_index]): + tgt_len = tgt_lenths_cpu_tensor[ind].item() + target = targets_cpu_tensor[ind][:tgt_len].numpy().tolist() + reference = self.decode_tokens_to_str(target) + references.append(reference) + if self.ctc_decode: + hypotheses = self.ctc_decoder_predictions_tensor(predictions, predictions_lengths) + else: + raise NotImplementedError("Implement me if you need non-CTC decode on predictions") + + if self.log_prediction or log_prediction: + logging.info(f"\n") + logging.info(f"reference:{references[0]}") + logging.info(f"predicted:{hypotheses[0]}") + + if self.strip_end_space: + # in my_preprocess, there may be space added at the end + references = [ref_i.rstrip(' ') for ref_i in references] + hypotheses = [hyp_i.rstrip(' ') for hyp_i in hypotheses] + + if decode_results is not None: + decode_results['references'] = references + decode_results['hypotheses'] = hypotheses + + for h, r in zip(hypotheses, references): + if self.use_cer: + h_list = list(h) + r_list = list(r) + else: + h_list = h.split() + r_list = r.split() + words += len(r_list) + # Compute Levenstein's distance + scores += editdistance.eval(h_list, r_list) + + self.scores = torch.tensor(scores, device=self.scores.device, dtype=self.scores.dtype) + self.words = torch.tensor(words, device=self.words.device, dtype=self.words.dtype) + # return torch.tensor([scores, words]).to(predictions.device) + + def compute(self): + scores = self.scores.detach().float() + words = self.words.detach().float() + return scores / words, scores, words + +class WER_phone(WER): + + def decode_tokens_to_str(self, tokens: List[int]) -> str: + """ + Implemented in order to decoder a token list into a string. + + Args: + tokens: List of int representing the token ids. + + Returns: + A decoded string. + """ + hypothesis = ' '.join(self.decode_ids_to_tokens(tokens)) + return hypothesis \ No newline at end of file diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/metrics/wer_bpe.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/metrics/wer_bpe.py new file mode 100644 index 0000000000000000000000000000000000000000..669e1daeb47edeebcbd3d6e1cd09437f87585deb --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/metrics/wer_bpe.py @@ -0,0 +1,248 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List + +import editdistance +import torch +from pytorch_lightning.metrics import Metric + +from nemo.collections.asr.parts.rnnt_utils import Hypothesis +from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec +from nemo.utils import logging + + +class WERBPE(Metric): + """ + This metric computes numerator and denominator for Overall Word Error Rate for BPE tokens (WER-BPE) between prediction and reference texts. + When doing distributed training/evaluation the result of res=WERBPE(predictions, targets, target_lengths) calls + will be all-reduced between all workers using SUM operations. + Here contains two numbers res=[wer_numerator, wer_denominator]. WERBPE=wer_numerator/wer_denominator. + + If used with PytorchLightning LightningModule, include wer_numerator and wer_denominators inside validation_step results. + Then aggregate (sum) then at the end of validation epoch to correctly compute validation WER. + + Example: + def validation_step(self, batch, batch_idx): + ... + wer_num, wer_denom = self.__wer(predictions, transcript, transcript_len) + return {'val_loss': loss_value, 'val_wer_num': wer_num, 'val_wer_denom': wer_denom} + + def validation_epoch_end(self, outputs): + ... + wer_num = torch.stack([x['val_wer_num'] for x in outputs]).sum() + wer_denom = torch.stack([x['val_wer_denom'] for x in outputs]).sum() + tensorboard_logs = {'validation_loss': val_loss_mean, 'validation_avg_wer': wer_num / wer_denom} + return {'val_loss': val_loss_mean, 'log': tensorboard_logs} + + Args: + vocabulary: NeMo tokenizer object, which inherits from TokenizerSpec. + batch_dim_index: Index of the batch dimension. + use_cer: Whether to compute word-error-rate or character-error-rate. + ctc_decode: Whether to perform CTC decode. + log_prediction: Whether to log a single decoded sample per call. + + Returns: + res: a torch.Tensor object with two elements: [wer_numerator, wer_denominators]. To correctly compute average + text word error rate, compute wer=wer_numerator/wer_denominators + """ + + def __init__( + self, + tokenizer: TokenizerSpec, + *, + blank_id, + batch_dim_index=0, + use_cer=False, + ctc_decode=True, + log_prediction=True, + dist_sync_on_step=False, + lang='en', + ): + super().__init__(dist_sync_on_step=dist_sync_on_step, compute_on_step=False) + self.tokenizer = tokenizer + self.batch_dim_index = batch_dim_index + self.blank_id = blank_id + self.use_cer = use_cer + self.ctc_decode = ctc_decode + self.log_prediction = log_prediction + self.lang = lang + print('WER lang: ', self.lang) + + self.add_state("scores", default=torch.tensor(0), dist_reduce_fx='sum', persistent=False) + self.add_state("words", default=torch.tensor(0), dist_reduce_fx='sum', persistent=False) + + def ctc_decoder_predictions_tensor( + self, predictions: torch.Tensor, predictions_len: torch.Tensor = None, return_hypotheses: bool = False + ) -> List[str]: + """ + Decodes a sequence of labels to words + + Args: + predictions: A torch.Tensor of shape [Batch, Time] of integer indices that correspond + to the index of some character in the vocabulary of the tokenizer. + predictions_len: Optional tensor of length `Batch` which contains the integer lengths + of the sequence in the padded `predictions` tensor. + return_hypotheses: Bool flag whether to return just the decoding predictions of the model + or a Hypothesis object that holds information such as the decoded `text`, + the `alignment` of emited by the CTC Model, and the `length` of the sequence (if available). + May also contain the log-probabilities of the decoder (if this method is called via + transcribe()) inside `y_sequence`, otherwise it is set None as it is a duplicate of + `alignments`. + + Returns: + Either a list of str which represent the CTC decoded strings per sample, + or a list of Hypothesis objects containing additional information. + """ + hypotheses = [] + # Drop predictions to CPU + prediction_cpu_tensor = predictions.long().cpu() + # iterate over batch + for ind in range(prediction_cpu_tensor.shape[self.batch_dim_index]): + prediction = prediction_cpu_tensor[ind].detach().numpy().tolist() + if predictions_len is not None: + prediction = prediction[: predictions_len[ind]] + # CTC decoding procedure + decoded_prediction = [] + previous = self.blank_id + for p in prediction: + if (p != previous or previous == self.blank_id) and p != self.blank_id: + decoded_prediction.append(p) + previous = p + + text = self.decode_tokens_to_str(decoded_prediction) + + if not return_hypotheses: + hypothesis = text + + # # By Dehua + # ctc_tokens = [] + # for _id in prediction: + # if _id == self.blank_id: + # ctc_tokens.append("CtcBlank") + # else: + # _token = self.tokenizer.ids_to_tokens([_id]) + # ctc_tokens += _token + # hypothesis = ' '.join(ctc_tokens) + else: + hypothesis = Hypothesis( + y_sequence=None, # logprob info added by transcribe method + score=-1.0, + text=text, + alignments=prediction, + length=predictions_len[ind] if predictions_len is not None else 0, + ) + hypotheses.append(hypothesis) + return hypotheses + + def decode_tokens_to_str(self, tokens: List[int]) -> str: + """ + Implemented in order to decoder a token list into a string. + + Args: + tokens: List of int representing the token ids. + + Returns: + A decoded string. + """ + hypothesis = self.tokenizer.ids_to_text(tokens) + return hypothesis + + def decode_ids_to_tokens(self, tokens: List[int]) -> List[str]: + """ + Implemented in order to decode a token id list into a token list. + A token list is the string representation of each token id. + + Args: + tokens: List of int representing the token ids. + + Returns: + A list of decoded tokens. + """ + token_list = self.tokenizer.ids_to_tokens(tokens) + return token_list + + def update( + self, + predictions: torch.Tensor, + targets: torch.Tensor, + target_lengths: torch.Tensor, + predictions_lengths: torch.Tensor = None, + log_prediction=False, + decode_results: dict = None, + ): + words = 0.0 + scores = 0.0 + references = [] + with torch.no_grad(): + # prediction_cpu_tensor = tensors[0].long().cpu() + targets_cpu_tensor = targets.long().cpu() + tgt_lenths_cpu_tensor = target_lengths.long().cpu() + + # iterate over batch + for ind in range(targets_cpu_tensor.shape[self.batch_dim_index]): + tgt_len = tgt_lenths_cpu_tensor[ind].item() + target = targets_cpu_tensor[ind][:tgt_len].numpy().tolist() + reference = self.decode_tokens_to_str(target) + references.append(reference) + if self.ctc_decode: + hypotheses = self.ctc_decoder_predictions_tensor(predictions, predictions_lengths) + else: + raise NotImplementedError("Implement me if you need non-CTC decode on predictions") + + if decode_results is not None: + decode_results['references'] = references + decode_results['hypotheses'] = hypotheses + + if self.log_prediction or log_prediction: + logging.info(f"\n") + logging.info(f"reference:{references[0]}") + logging.info(f"predicted:{hypotheses[0]}") + + for h, r in zip(hypotheses, references): + if self.use_cer: + h_list = list(h) + r_list = list(r) + else: + if self.lang == 'zh': + h = add_delim_space(h) + r = add_delim_space(r) + else: + assert self.lang == 'en' + + h_list = h.split() + r_list = r.split() + words += len(r_list) + # Compute Levenstein's distance + scores += editdistance.eval(h_list, r_list) + + self.scores = torch.tensor(scores, device=self.scores.device, dtype=self.scores.dtype) + self.words = torch.tensor(words, device=self.words.device, dtype=self.words.dtype) + # return torch.tensor([scores, words]).to(predictions.device) + + def compute(self): + scores = self.scores.detach().float() + words = self.words.detach().float() + return scores / words, scores, words + + +import re +PAT_ZH_SPACE = re.compile(r'(?<=[\u4e00-\u9fff])(?=[\u4e00-\u9fff])|(?<=[\u4e00-\u9fff])(?=[a-zA-Z])|(?<=[a-zA-Z])(?=[\u4e00-\u9fff])') +PAT_SPACE = re.compile(r'\s+') + + +def add_delim_space(text): + text = PAT_ZH_SPACE.sub(' ', text) + text = PAT_SPACE.sub(' ', text) + return text diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/__init__.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b0adfef37cc4b57d76c0278a8fa48690b2c4440f --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/__init__.py @@ -0,0 +1,21 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nemo.collections.asr.models.asr_model import ASRModel +from nemo.collections.asr.models.classification_models import EncDecClassificationModel +from nemo.collections.asr.models.ctc_bpe_models import EncDecCTCModelBPE +from nemo.collections.asr.models.ctc_models import EncDecCTCModel +from nemo.collections.asr.models.label_models import EncDecSpeakerLabelModel, ExtractSpeakerEmbeddingsModel +from nemo.collections.asr.models.rnnt_bpe_models import EncDecRNNTBPEModel +from nemo.collections.asr.models.rnnt_models import EncDecRNNTModel diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/asr_model.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/asr_model.py new file mode 100644 index 0000000000000000000000000000000000000000..3dde23bd587de1374421d43074b0d132698c69d1 --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/asr_model.py @@ -0,0 +1,91 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from abc import ABC, abstractmethod +from typing import List, Optional, Union + +import torch +from omegaconf import OmegaConf + +from nemo.core.classes import ModelPT +from nemo.core.classes.exportable import Exportable +from nemo.utils import model_utils + +__all__ = ['ASRModel'] + + +class ASRModel(ModelPT, ABC): + @abstractmethod + def transcribe(self, paths2audio_files: List[str], batch_size: int = 4) -> List[str]: + """ + Takes paths to audio files and returns text transcription + Args: + paths2audio_files: paths to audio fragment to be transcribed + + Returns: + transcription texts + """ + pass + + def multi_validation_epoch_end(self, outputs, dataloader_idx: int = 0): + val_loss_mean = torch.stack([x['val_loss'] for x in outputs]).mean() + wer_num = torch.stack([x['val_wer_num'] for x in outputs]).sum() + wer_denom = torch.stack([x['val_wer_denom'] for x in outputs]).sum() + tensorboard_logs = {'val_loss': val_loss_mean, 'val_wer': wer_num / wer_denom} + return {'val_loss': val_loss_mean, 'log': tensorboard_logs} + + def multi_test_epoch_end(self, outputs, dataloader_idx: int = 0): + val_loss_mean = torch.stack([x['test_loss'] for x in outputs]).mean() + wer_num = torch.stack([x['test_wer_num'] for x in outputs]).sum() + wer_denom = torch.stack([x['test_wer_denom'] for x in outputs]).sum() + tensorboard_logs = {'test_loss': val_loss_mean, 'test_wer': wer_num / wer_denom} + return {'test_loss': val_loss_mean, 'log': tensorboard_logs} + + @classmethod + def list_available_models(cls) -> 'List[PretrainedModelInfo]': + """ + This method returns a list of pre-trained model which can be instantiated directly from NVIDIA's NGC cloud. + Returns: + List of available pre-trained models. + """ + # recursively walk the subclasses to generate pretrained model info + list_of_models = model_utils.resolve_subclass_pretrained_model_info(cls) + return list_of_models + + +class ExportableEncDecModel(Exportable): + """ + Simple utiliy mix-in to export models that consist of encoder/decoder pair + plus pre/post processor, but have to be exported as encoder/decoder pair only + (covers most ASR classes) + """ + + @property + def input_module(self): + return self.encoder + + @property + def output_module(self): + return self.decoder + + def forward_for_export(self, input, length=None): + encoder_output = self.input_module(input, length) + if isinstance(encoder_output, tuple): + return self.output_module(encoder_output[0]) + else: + return self.output_module(encoder_output) + + def _prepare_for_export(self, **kwargs): + self.input_module._prepare_for_export(**kwargs) + self.output_module._prepare_for_export(**kwargs) diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/classification_models.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/classification_models.py new file mode 100644 index 0000000000000000000000000000000000000000..81da896c149bc688508e1b45804618cd7b6fc423 --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/classification_models.py @@ -0,0 +1,538 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import json +import os +import tempfile +from typing import Dict, List, Optional, Union + +import onnx +import torch +from omegaconf import DictConfig, ListConfig, OmegaConf +from pytorch_lightning import Trainer + +from nemo.collections.asr.data.audio_to_label import AudioToSpeechLabelDataSet +from nemo.collections.asr.data.audio_to_text import AudioLabelDataset +from nemo.collections.asr.models.asr_model import ASRModel +from nemo.collections.asr.parts.features import WaveformFeaturizer +from nemo.collections.asr.parts.perturb import process_augmentations +from nemo.collections.common.losses import CrossEntropyLoss +from nemo.collections.common.metrics import TopKClassificationAccuracy +from nemo.core.classes.common import PretrainedModelInfo, typecheck +from nemo.core.classes.exportable import Exportable +from nemo.core.neural_types import * +from nemo.utils import logging +from nemo.utils.export_utils import attach_onnx_to_onnx + +__all__ = ['EncDecClassificationModel', 'MatchboxNet'] + + +class EncDecClassificationModel(ASRModel, Exportable): + """Encoder decoder CTC-based models.""" + + def __init__(self, cfg: DictConfig, trainer: Trainer = None): + super().__init__(cfg=cfg, trainer=trainer) + self._update_decoder_config(self._cfg.decoder) + + self.preprocessor = EncDecClassificationModel.from_config_dict(self._cfg.preprocessor) + self.encoder = EncDecClassificationModel.from_config_dict(self._cfg.encoder) + self.decoder = EncDecClassificationModel.from_config_dict(self._cfg.decoder) + self.loss = CrossEntropyLoss() + if hasattr(self._cfg, 'spec_augment') and self._cfg.spec_augment is not None: + self.spec_augmentation = EncDecClassificationModel.from_config_dict(self._cfg.spec_augment) + else: + self.spec_augmentation = None + if hasattr(self._cfg, 'crop_or_pad_augment') and self._cfg.crop_or_pad_augment is not None: + self.crop_or_pad = EncDecClassificationModel.from_config_dict(self._cfg.crop_or_pad_augment) + else: + self.crop_or_pad = None + + # Setup metric objects + self._accuracy = TopKClassificationAccuracy(dist_sync_on_step=True) + + @torch.no_grad() + def transcribe(self, paths2audio_files: List[str], batch_size: int = 4, logprobs=False) -> List[str]: + """ + Generate class labels for provided audio files. Use this method for debugging and prototyping. + + Args: + paths2audio_files: (a list) of paths to audio files. \ + Recommended length per file is approximately 1 second. + batch_size: (int) batch size to use during inference. \ + Bigger will result in better throughput performance but would use more memory. + logprobs: (bool) pass True to get log probabilities instead of class labels. + + Returns: + + A list of transcriptions (or raw log probabilities if logprobs is True) in the same order as paths2audio_files + """ + if paths2audio_files is None or len(paths2audio_files) == 0: + return {} + # We will store transcriptions here + labels = [] + # Model's mode and device + mode = self.training + device = next(self.parameters()).device + dither_value = self.preprocessor.featurizer.dither + pad_to_value = self.preprocessor.featurizer.pad_to + + try: + self.preprocessor.featurizer.dither = 0.0 + self.preprocessor.featurizer.pad_to = 0 + # Switch model to evaluation mode + self.eval() + logging_level = logging.get_verbosity() + logging.set_verbosity(logging.WARNING) + # Work in tmp directory - will store manifest file there + with tempfile.TemporaryDirectory() as tmpdir: + with open(os.path.join(tmpdir, 'manifest.json'), 'w') as fp: + for audio_file in paths2audio_files: + entry = {'audio_filepath': audio_file, 'duration': 100000.0, 'label': self.cfg.labels[0]} + fp.write(json.dumps(entry) + '\n') + + config = {'paths2audio_files': paths2audio_files, 'batch_size': batch_size, 'temp_dir': tmpdir} + + temporary_datalayer = self._setup_transcribe_dataloader(config) + for test_batch in temporary_datalayer: + logits = self.forward( + input_signal=test_batch[0].to(device), input_signal_length=test_batch[1].to(device) + ) + if logprobs: + # dump log probs per file + for idx in range(logits.shape[0]): + labels.append(logits[idx]) + else: + labels_k = [] + top_ks = self._accuracy.top_k + for top_k_i in top_ks: + # replace top k value with current top k + self._accuracy.top_k = top_k_i + labels_k_i = self._accuracy.top_k_predicted_labels(logits) + labels_k.append(labels_k_i) + + # convenience: if only one top_k, pop out the nested list + if len(top_ks) == 1: + labels_k = labels_k[0] + + labels += labels_k + # reset top k to orignal value + self._accuracy.top_k = top_ks + del test_batch + finally: + # set mode back to its original value + self.train(mode=mode) + self.preprocessor.featurizer.dither = dither_value + self.preprocessor.featurizer.pad_to = pad_to_value + logging.set_verbosity(logging_level) + return labels + + def _setup_dataloader_from_config(self, config: Optional[Dict]): + if config.get('manifest_filepath') is None: + return + + if 'augmentor' in config: + augmentor = process_augmentations(config['augmentor']) + else: + augmentor = None + + featurizer = WaveformFeaturizer( + sample_rate=config['sample_rate'], int_values=config.get('int_values', False), augmentor=augmentor + ) + + if 'vad_stream' in config and config['vad_stream']: + print("Perform streaming frame-level VAD") + dataset = AudioToSpeechLabelDataSet( + manifest_filepath=config['manifest_filepath'], + labels=config['labels'], + featurizer=featurizer, + max_duration=config.get('max_duration', None), + min_duration=config.get('min_duration', None), + trim=config.get('trim_silence', True), + load_audio=config.get('load_audio', True), + time_length=config.get('time_length', 0.31), + shift_length=config.get('shift_length', 0.01), + ) + batch_size = 1 + collate_func = dataset.vad_frame_seq_collate_fn + else: + dataset = AudioLabelDataset( + manifest_filepath=config['manifest_filepath'], + labels=config['labels'], + featurizer=featurizer, + max_duration=config.get('max_duration', None), + min_duration=config.get('min_duration', None), + trim=config.get('trim_silence', True), + load_audio=config.get('load_audio', True), + ) + batch_size = config['batch_size'] + collate_func = dataset.collate_fn + + return torch.utils.data.DataLoader( + dataset=dataset, + batch_size=batch_size, + collate_fn=collate_func, + drop_last=config.get('drop_last', False), + shuffle=config['shuffle'], + num_workers=config.get('num_workers', 0), + pin_memory=config.get('pin_memory', False), + ) + + def setup_training_data(self, train_data_config: Optional[Union[DictConfig, Dict]]): + if 'shuffle' not in train_data_config: + train_data_config['shuffle'] = True + self._train_dl = self._setup_dataloader_from_config(config=train_data_config) + + def setup_validation_data(self, val_data_config: Optional[Union[DictConfig, Dict]]): + if 'shuffle' not in val_data_config: + val_data_config['shuffle'] = False + self._validation_dl = self._setup_dataloader_from_config(config=val_data_config) + + def setup_test_data(self, test_data_config: Optional[Union[DictConfig, Dict]]): + if 'shuffle' not in test_data_config: + test_data_config['shuffle'] = False + self._test_dl = self._setup_dataloader_from_config(config=test_data_config) + + def test_dataloader(self): + if self._test_dl is not None: + return self._test_dl + + @classmethod + def list_available_models(cls) -> Optional[List[PretrainedModelInfo]]: + """ + This method returns a list of pre-trained model which can be instantiated directly from NVIDIA's NGC cloud. + + Returns: + List of available pre-trained models. + """ + result = [] + model = PretrainedModelInfo( + pretrained_model_name="MatchboxNet-3x1x64-v1", + location="https://api.ngc.nvidia.com/v2/models/nvidia/nemospeechmodels/versions/1.0.0a5/files/MatchboxNet-3x1x64-v1.nemo", + description="MatchboxNet model trained on Google Speech Commands dataset (v1, 30 classes) which obtains 97.32% accuracy on test set.", + ) + result.append(model) + + model = PretrainedModelInfo( + pretrained_model_name="MatchboxNet-3x2x64-v1", + location="https://api.ngc.nvidia.com/v2/models/nvidia/nemospeechmodels/versions/1.0.0a5/files/MatchboxNet-3x2x64-v1.nemo", + description="MatchboxNet model trained on Google Speech Commands dataset (v1, 30 classes) which obtains 97.68% accuracy on test set.", + ) + result.append(model) + + model = PretrainedModelInfo( + pretrained_model_name="MatchboxNet-3x1x64-v2", + location="https://api.ngc.nvidia.com/v2/models/nvidia/nemospeechmodels/versions/1.0.0a5/files/MatchboxNet-3x1x64-v2.nemo", + description="MatchboxNet model trained on Google Speech Commands dataset (v2, 35 classes) which obtains 97.12% accuracy on test set.", + ) + result.append(model) + + model = PretrainedModelInfo( + pretrained_model_name="MatchboxNet-3x1x64-v2", + location="https://api.ngc.nvidia.com/v2/models/nvidia/nemospeechmodels/versions/1.0.0a5/files/MatchboxNet-3x1x64-v2.nemo", + description="MatchboxNet model trained on Google Speech Commands dataset (v2, 30 classes) which obtains 97.29% accuracy on test set.", + ) + result.append(model) + + model = PretrainedModelInfo( + pretrained_model_name="MatchboxNet-3x1x64-v2-subset-task", + location="https://api.ngc.nvidia.com/v2/models/nvidia/nemospeechmodels/versions/1.0.0a5/files/MatchboxNet-3x1x64-v2-subset-task.nemo", + description="MatchboxNet model trained on Google Speech Commands dataset (v2, 10+2 classes) which obtains 98.2% accuracy on test set.", + ) + result.append(model) + + model = PretrainedModelInfo( + pretrained_model_name="MatchboxNet-3x2x64-v2-subset-task", + location="https://api.ngc.nvidia.com/v2/models/nvidia/nemospeechmodels/versions/1.0.0a5/files/MatchboxNet-3x2x64-v2-subset-task.nemo", + description="MatchboxNet model trained on Google Speech Commands dataset (v2, 10+2 classes) which obtains 98.4% accuracy on test set.", + ) + result.append(model) + + model = PretrainedModelInfo( + pretrained_model_name="MatchboxNet-VAD-3x2", + location="https://api.ngc.nvidia.com/v2/models/nvidia/nemospeechmodels/versions/1.0.0a5/files/MatchboxNet_VAD_3x2.nemo", + description="Voice Activity Detection MatchboxNet model trained on google speech command (v2) and freesound background data, which obtains 0.992 accuracy on testset from same source and 0.852 TPR for FPR=0.315 on testset (ALL) of AVA movie data", + ) + result.append(model) + return result + + @property + def input_types(self) -> Optional[Dict[str, NeuralType]]: + if hasattr(self.preprocessor, '_sample_rate'): + audio_eltype = AudioSignal(freq=self.preprocessor._sample_rate) + else: + audio_eltype = AudioSignal() + return { + "input_signal": NeuralType(('B', 'T'), audio_eltype), + "input_signal_length": NeuralType(tuple('B'), LengthsType()), + } + + @property + def output_types(self) -> Optional[Dict[str, NeuralType]]: + return {"outputs": NeuralType(('B', 'D'), LogitsType())} + + @typecheck() + def forward(self, input_signal, input_signal_length): + processed_signal, processed_signal_len = self.preprocessor( + input_signal=input_signal, length=input_signal_length, + ) + # Crop or pad is always applied + if self.crop_or_pad is not None: + processed_signal, processed_signal_len = self.crop_or_pad( + input_signal=processed_signal, length=processed_signal_len + ) + # Spec augment is not applied during evaluation/testing + if self.spec_augmentation is not None and self.training: + processed_signal = self.spec_augmentation(input_spec=processed_signal) + encoded, encoded_len = self.encoder(audio_signal=processed_signal, length=processed_signal_len) + logits = self.decoder(encoder_output=encoded) + return logits + + # PTL-specific methods + def training_step(self, batch, batch_nb): + self.training_step_end() + audio_signal, audio_signal_len, labels, labels_len = batch + logits = self.forward(input_signal=audio_signal, input_signal_length=audio_signal_len) + loss_value = self.loss(logits=logits, labels=labels) + + tensorboard_logs = { + 'train_loss': loss_value, + 'learning_rate': self._optimizer.param_groups[0]['lr'], + } + + self._accuracy(logits=logits, labels=labels) + top_k = self._accuracy.compute() + for i, top_i in enumerate(top_k): + tensorboard_logs[f'training_batch_accuracy_top@{i}'] = top_i + + return {'loss': loss_value, 'log': tensorboard_logs} + + def validation_step(self, batch, batch_idx, dataloader_idx=0): + audio_signal, audio_signal_len, labels, labels_len = batch + logits = self.forward(input_signal=audio_signal, input_signal_length=audio_signal_len) + loss_value = self.loss(logits=logits, labels=labels) + acc = self._accuracy(logits=logits, labels=labels) + correct_counts, total_counts = self._accuracy.correct_counts_k, self._accuracy.total_counts_k + return { + 'val_loss': loss_value, + 'val_correct_counts': correct_counts, + 'val_total_counts': total_counts, + 'val_acc': acc, + } + + def test_step(self, batch, batch_idx, dataloader_idx=0): + audio_signal, audio_signal_len, labels, labels_len = batch + logits = self.forward(input_signal=audio_signal, input_signal_length=audio_signal_len) + loss_value = self.loss(logits=logits, labels=labels) + acc = self._accuracy(logits=logits, labels=labels) + correct_counts, total_counts = self._accuracy.correct_counts_k, self._accuracy.total_counts_k + return { + 'test_loss': loss_value, + 'test_correct_counts': correct_counts, + 'test_total_counts': total_counts, + 'test_acc': acc, + } + + def multi_validation_epoch_end(self, outputs, dataloader_idx: int = 0): + val_loss_mean = torch.stack([x['val_loss'] for x in outputs]).mean() + correct_counts = torch.stack([x['val_correct_counts'] for x in outputs]).sum(axis=0) + total_counts = torch.stack([x['val_total_counts'] for x in outputs]).sum(axis=0) + + self._accuracy.correct_counts_k = correct_counts + self._accuracy.total_counts_k = total_counts + topk_scores = self._accuracy.compute() + + tensorboard_log = {'val_loss': val_loss_mean} + for top_k, score in zip(self._accuracy.top_k, topk_scores): + tensorboard_log['val_epoch_top@{}'.format(top_k)] = score + + return {'log': tensorboard_log} + + def multi_test_epoch_end(self, outputs, dataloader_idx: int = 0): + test_loss_mean = torch.stack([x['test_loss'] for x in outputs]).mean() + correct_counts = torch.stack([x['test_correct_counts'].unsqueeze(0) for x in outputs]).sum(axis=0) + total_counts = torch.stack([x['test_total_counts'].unsqueeze(0) for x in outputs]).sum(axis=0) + + self._accuracy.correct_counts_k = correct_counts + self._accuracy.total_counts_k = total_counts + topk_scores = self._accuracy.compute() + + tensorboard_log = {'test_loss': test_loss_mean} + for top_k, score in zip(self._accuracy.top_k, topk_scores): + tensorboard_log['test_epoch_top@{}'.format(top_k)] = score + + return {'log': tensorboard_log} + + def change_labels(self, new_labels: List[str]): + """ + Changes labels used by the decoder model. Use this method when fine-tuning on from pre-trained model. + This method changes only decoder and leaves encoder and pre-processing modules unchanged. For example, you would + use it if you want to use pretrained encoder when fine-tuning on a data in another dataset. + + If new_labels == self.decoder.vocabulary then nothing will be changed. + + Args: + + new_labels: list with new labels. Must contain at least 2 elements. Typically, \ + this is set of labels for the dataset. + + Returns: None + + """ + if new_labels is not None and not isinstance(new_labels, ListConfig): + new_labels = ListConfig(new_labels) + + if self._cfg.labels == new_labels: + logging.warning( + f"Old labels ({self._cfg.labels}) and new labels ({new_labels}) match. Not changing anything" + ) + else: + if new_labels is None or len(new_labels) == 0: + raise ValueError(f'New labels must be non-empty list of labels. But I got: {new_labels}') + + # Update config + self._cfg.labels = new_labels + + decoder_config = self.decoder.to_config_dict() + new_decoder_config = copy.deepcopy(decoder_config) + self._update_decoder_config(new_decoder_config) + del self.decoder + self.decoder = EncDecClassificationModel.from_config_dict(new_decoder_config) + + OmegaConf.set_struct(self._cfg.decoder, False) + self._cfg.decoder = new_decoder_config + OmegaConf.set_struct(self._cfg.decoder, True) + + if 'train_ds' in self._cfg and self._cfg.train_ds is not None: + self._cfg.train_ds.labels = new_labels + + if 'validation_ds' in self._cfg and self._cfg.validation_ds is not None: + self._cfg.validation_ds.labels = new_labels + + if 'test_ds' in self._cfg and self._cfg.test_ds is not None: + self._cfg.test_ds.labels = new_labels + + logging.info(f"Changed decoder output to {self.decoder.num_classes} labels.") + + def _update_decoder_config(self, cfg): + """ + Update the number of classes in the decoder based on labels provided. + + Args: + cfg: The config of the decoder which will be updated. + """ + OmegaConf.set_struct(cfg, False) + + labels = self.cfg.labels + + if 'params' in cfg: + cfg.params.num_classes = len(labels) + else: + cfg.num_classes = len(labels) + + OmegaConf.set_struct(cfg, True) + + def _setup_transcribe_dataloader(self, config: Dict) -> 'torch.utils.data.DataLoader': + """ + Setup function for a temporary data loader which wraps the provided audio file. + + Args: + config: A python dictionary which contains the following keys: + paths2audio_files: (a list) of paths to audio files. The files should be relatively short fragments. \ + Recommended length per file is between 5 and 25 seconds. + batch_size: (int) batch size to use during inference. \ + Bigger will result in better throughput performance but would use more memory. + temp_dir: (str) A temporary directory where the audio manifest is temporarily + stored. + + Returns: + A pytorch DataLoader for the given audio file(s). + """ + dl_config = { + 'manifest_filepath': os.path.join(config['temp_dir'], 'manifest.json'), + 'sample_rate': self.preprocessor._sample_rate, + 'labels': self.cfg.labels, + 'batch_size': min(config['batch_size'], len(config['paths2audio_files'])), + 'trim_silence': False, + 'shuffle': False, + } + + temporary_datalayer = self._setup_dataloader_from_config(config=DictConfig(dl_config)) + return temporary_datalayer + + def export( + self, + output: str, + input_example=None, + output_example=None, + verbose=False, + export_params=True, + do_constant_folding=True, + keep_initializers_as_inputs=False, + onnx_opset_version: int = 12, + try_script: bool = False, + set_eval: bool = True, + check_trace: bool = True, + use_dynamic_axes: bool = True, + ): + if input_example is not None or output_example is not None: + logging.warning( + "Passed input and output examples will be ignored and recomputed since" + " EncDecClassificationModel consists of two separate models (encoder and decoder) with different" + " inputs and outputs." + ) + + qual_name = self.__module__ + '.' + self.__class__.__qualname__ + output1 = os.path.join(os.path.dirname(output), 'encoder_' + os.path.basename(output)) + output1_descr = qual_name + ' Encoder exported to ONNX' + encoder_onnx = self.encoder.export( + output1, + None, # computed by input_example() + None, + verbose, + export_params, + do_constant_folding, + keep_initializers_as_inputs, + onnx_opset_version, + try_script, + set_eval, + check_trace, + use_dynamic_axes, + ) + + output2 = os.path.join(os.path.dirname(output), 'decoder_' + os.path.basename(output)) + output2_descr = qual_name + ' Decoder exported to ONNX' + decoder_onnx = self.decoder.export( + output2, + None, # computed by input_example() + None, + verbose, + export_params, + do_constant_folding, + keep_initializers_as_inputs, + onnx_opset_version, + try_script, + set_eval, + check_trace, + use_dynamic_axes, + ) + + output_model = attach_onnx_to_onnx(encoder_onnx, decoder_onnx, "EDC") + output_descr = qual_name + ' Encoder+Decoder exported to ONNX' + onnx.save(output_model, output) + return ([output, output1, output2], [output_descr, output1_descr, output2_descr]) + + +class MatchboxNet(EncDecClassificationModel): + pass diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/configs/__init__.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/configs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..076c352663ee3b28d748cc15250b3051806e455b --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/configs/__init__.py @@ -0,0 +1,23 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nemo.collections.asr.models.configs.ctc_models_config import ( + EncDecCTCConfig, + EncDecCTCModelConfig, +) +from nemo.collections.asr.modules.audio_preprocessing import ( + AudioToMelSpectrogramPreprocessorConfig, + SpectrogramAugmentationConfig, +) +from nemo.collections.asr.modules.conv_asr import ConvASREncoderConfig, JasperEncoderConfig diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/configs/common_config.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/configs/common_config.py new file mode 100644 index 0000000000000000000000000000000000000000..d04d072bf96caa5fba775198a426302b34315af3 --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/configs/common_config.py @@ -0,0 +1,223 @@ +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple + +from omegaconf import MISSING + +import nemo.core.classes.dataset +from nemo.core.config.modelPT import SchedConfig, OptimConfig + +__all__ = ['Conv2dNormAct', 'Conv1dNormAct', 'Conv2dBlock', 'DatasetConfig', 'OptimConfig', 'NovogradParams', + 'WarmupParams', 'WarmupHoldParams', 'PolynomialDecayAnnealingParams', 'PolynomialHoldDecayAnnealingParams', + 'Tokenizer'] + + +@dataclass +class Conv2dNormAct: + filters: int = MISSING + kernel_size: Tuple[int, int] = MISSING + stride: Tuple[int, int] = MISSING + norm_type: Optional[str] = MISSING + gn_groups: Optional[int] = None + act_func: Optional[str] = MISSING + dilation: Tuple[int, int] = (1, 1) + dropout: float = 0.0 + padding: str = 'same' + bias: Optional[bool] = None + residual: bool = False + + +@dataclass +class Conv1dNormAct: + filters: int = MISSING + kernel_size: Tuple[int] = MISSING + stride: Tuple[int] = MISSING + norm_type: Optional[str] = MISSING + gn_groups: Optional[int] = None + act_func: Optional[str] = MISSING + dilation: Tuple[int] = (1,) + dropout: float = 0.0 + padding: str = 'same' + bias: Optional[bool] = None + residual: bool = False + + +@dataclass +class ConvFFN: + filters: int = MISSING + kernel_size: int = MISSING + act_func: Optional[str] = MISSING + norm_type: str = 'post_ln' + dropout: float = 0.0 + padding: str = 'same' + # use_tf_pad: bool = True + # ln_eps: float = 1e-5 + + +@dataclass +class ConvBlock: + feat_in: int = MISSING + use_conv_mask: bool = MISSING + conv_layers: List[Conv1dNormAct] = MISSING + conv_ffn_layers: Optional[List[ConvFFN]] = None + output_proj_dim: Optional[int] = None + use_tf_pad: bool = True + ln_eps: float = 1e-5 + + +@dataclass +class ProjUpsampling: + rate: int = MISSING + filters: int = MISSING + kernel_size: Tuple[int] = MISSING + norm_type: Optional[str] = MISSING + act_func: Optional[str] = MISSING + dropout: float = 0.0 + padding: str = 'same' + bias: bool = True + + +@dataclass +class Conv2dBlock: + layers: List[Conv2dNormAct] = MISSING + output_dim: int = MISSING + + +@dataclass +class DatasetConfig(nemo.core.classes.dataset.DatasetConfig): + manifest_dir: str = MISSING + data_dir: str = MISSING + manifest_filepath: str = MISSING + sample_rate: int = MISSING + labels: List[str] = MISSING + trim_silence: bool = False + + # Optional + int_values: bool = False + augmentor: Optional[Dict[str, Any]] = None + max_duration: Optional[float] = None + min_duration: Optional[float] = None + max_utts: int = 0 + dup_factor: int = 1 + blank_index: int = -1 + unk_index: int = -1 + normalize: bool = False + trim: bool = True + load_audio: bool = True + parser: Optional[str] = 'en' + parser_add_end_space: bool = False + add_misc: bool = False + subword_sampling_nbest_size: Optional[int] = None + subword_sampling_alpha: Optional[float] = None + + +@dataclass +class AudioDatasetConfig(nemo.core.classes.dataset.DatasetConfig): + manifest_dir: str = MISSING + data_dir: str = MISSING + manifest_filepath: str = MISSING + + sample_rate: int = MISSING + + max_duration: Optional[float] = None + min_duration: Optional[float] = None + crop_size: Optional[int] = None + + +@dataclass +class AdamWParams(OptimConfig): + name: str = 'adamw' + + betas: Tuple[float, float] = (0.9, 0.999) + eps: float = 1e-08 + weight_decay: float = 0 + amsgrad: bool = False + + + +@dataclass +class AdamParams(OptimConfig): + name: str = 'adam' + + betas: Tuple[float, float] = (0.9, 0.999) + eps: float = 1e-08 + weight_decay: float = 0 + amsgrad: bool = False + + +@dataclass +class NovogradParams(OptimConfig): + name: str = 'novograd' + + betas: Tuple[float, float] = (0.95, 0.98) + eps: float = 1e-8 + eps_in_sqrt: bool = False + weight_decay: float = 0 + weight_decay_ema: bool = True + grad_averaging: bool = False + amsgrad: bool = False + luc: bool = False + luc_grad_trust: float = 0.0 + luc_grad_trust_rel: bool = False + luc_trust: float = 1e-3 + luc_trust_min: float = 0.0 + luc_eps: float = 1e-8 + luc_update_min: float = 1e-7 + luc_update_max: float = 1.0 + + +@dataclass +class WarmupParams: + warmup_steps: Optional[int] = None + warmup_ratio: Optional[float] = None + warmup_power: Optional[float] = None + + +@dataclass +class WarmupHoldParams: + hold_steps: Optional[int] = None + hold_ratio: Optional[float] = None + + +@dataclass +class PolynomialDecayAnnealingParams(SchedConfig, WarmupParams): + name: str = 'PolynomialDecayAnnealing' + max_steps: int = MISSING + power: float = 1.0 + cycle: bool = False + + +@dataclass +class PolynomialHoldDecayAnnealingParams(PolynomialDecayAnnealingParams, WarmupHoldParams): + name: str = 'PolynomialHoldDecayAnnealing' + + +@dataclass +class CosineAnnealingParams(SchedConfig, WarmupParams): + name: str = 'CosineAnnealing' + max_steps: int = MISSING + + +@dataclass +class Tokenizer: + dir: str = MISSING # path to directory which contains either tokenizer.model (bpe) or vocab.txt (for wpe) + file: Optional[str] = None + type: str = 'bpe' + prepend_unk_to_vocab: bool = True + + +@dataclass +class G2PTableTokenizer(Tokenizer): + type: str = 'g2p_table' + prepend_unk_to_vocab: bool = False + lexicon_file: str = MISSING + + +@dataclass +class TextEmbed: + embed_num: int = MISSING + embed_size: int = MISSING + embed_proj_size: int = MISSING + embed_dropout: float = MISSING + norm_embed: bool = False + mask_label_prob: float = 0.0 + mask_label_id: Optional[int] = None diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/configs/conv_transformer_config.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/configs/conv_transformer_config.py new file mode 100644 index 0000000000000000000000000000000000000000..65e89d621fc073a6362bc6c576908615ad4ea325 --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/configs/conv_transformer_config.py @@ -0,0 +1,57 @@ +from dataclasses import dataclass +from typing import List, Optional + +from omegaconf import MISSING + +from nemo.collections.asr.models.configs.common_config import Conv2dBlock, Conv1dNormAct + +__all__ = ['AdaptiveFFN', 'RelTransformerBlock', 'ConvTransformerBlock', 'ConvTransformerEncoder'] + + +@dataclass +class AdaptiveFFN: + gate_mix_prob: float = MISSING + gate_sample_prob: float = MISSING + identity_threshold: float = MISSING + identity_loss_weight: float = MISSING + init_identiy_bias: float = 0.0 + ffn_residual: bool = True + norm_in_ffn: bool = True + + +@dataclass +class RelTransformerBlock: + n_layer: int = MISSING + d_model: int = MISSING + n_head: int = MISSING + d_head: int = MISSING + d_inner: int = MISSING + dropout: float = MISSING + dropout_att: float = MISSING + pre_lnorm: bool = False + norm_output: bool = False + uni_attn: bool = True + norm_type: str = 'ln' + pos_enc: Optional[str] = 'xl' + layer_drop: float = 0.0 + + adaptive_ffn: Optional[AdaptiveFFN] = None + + +@dataclass +class ConvTransformerBlock: + conv_layers: List[Conv1dNormAct] + transformer_block: Optional[RelTransformerBlock] = MISSING + + +@dataclass +class ConvTransformerEncoder: + _target_: str = 'nemo.collections.asr.modules.ConvTransformerEncoder' + feat_in: int = MISSING + use_conv_mask: bool = MISSING + conv2d_block: Optional[Conv2dBlock] = MISSING + conv_transformer_blocks: List[ConvTransformerBlock] = MISSING + use_tf_pad: bool = True + ln_eps: float = 1e-5 + init_mode: str = 'xavier_uniform' + bias_init_mode: str = 'zero' diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/configs/convtt_models_config.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/configs/convtt_models_config.py new file mode 100644 index 0000000000000000000000000000000000000000..f4e54a9a61f7800ba1e9f69bf477626891606afd --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/configs/convtt_models_config.py @@ -0,0 +1,42 @@ +from dataclasses import dataclass +from typing import List, Optional + +from omegaconf import MISSING + +from nemo.collections.asr.modules.audio_preprocessing import ( + AudioToMelSpectrogramPreprocessorConfig, + SpectrogramAugmentationConfig, +) +from nemo.core.config import modelPT as model_cfg + +from .common_config import * +from .transducer_config import * +from .conv_transformer_config import * + + +@dataclass +class ConvTTModel(model_cfg.ModelConfig): + labels: List[str] = MISSING + tokenizer: Optional[Tokenizer] = None + + train_ds: DatasetConfig = MISSING + validation_ds: DatasetConfig = MISSING + test_ds: DatasetConfig = MISSING + + expected_gpu_num: int = MISSING + optim: Optional[OptimConfig] = MISSING + + preprocessor: AudioToMelSpectrogramPreprocessorConfig = MISSING + spec_augment: Optional[SpectrogramAugmentationConfig] = MISSING + + encoder: ConvTransformerEncoder = MISSING + decoder: TransformerTDecoder = MISSING + joint: RNNTJoint = MISSING + model_defaults: ModelDefaults = MISSING + + decoding: Decoding = MISSING + + +@dataclass +class ConvTTConfig(model_cfg.ModelPTConfig): + model: ConvTTModel = ConvTTModel() diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/configs/ctc_models_config.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/configs/ctc_models_config.py new file mode 100644 index 0000000000000000000000000000000000000000..b18a9b5f4636bbca382a0c97ba51fe2aeb31465a --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/configs/ctc_models_config.py @@ -0,0 +1,70 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field +from typing import Any, List, Optional + +from omegaconf import MISSING + +from nemo.collections.asr.models.configs.common_config import Tokenizer, DatasetConfig, OptimConfig, Conv1dNormAct, \ + ProjUpsampling +from nemo.collections.asr.models.spec2vec.spec2vec_config import ProjectorConfig +from nemo.collections.asr.modules.audio_preprocessing import ( + AudioToMelSpectrogramPreprocessorConfig, + SpectrogramAugmentationConfig, +) +from nemo.core.config import modelPT as model_cfg + + +@dataclass +class ConvASRDecoderConfig: + _target_: str = 'nemo.collections.asr.modules.ConvASRDecoder' + feat_in: int = MISSING + num_classes: int = 0 + proj_upsampling: Optional[ProjUpsampling] = None + conv_layers: Optional[List[Conv1dNormAct]] = None + projector: Optional[ProjectorConfig] = None + use_conv_mask: bool = True + use_tf_pad: bool = True + ln_eps: float = 1e-5 + blank_pos: str = 'after_vocab_last' + init_mode: str = 'xavier_uniform' + vocabulary: Optional[List[str]] = field(default_factory=list) + + +@dataclass +class EncDecCTCConfig(model_cfg.ModelConfig): + # Model global arguments + sample_rate: int = 16000 + labels: List[str] = MISSING + tokenizer: Optional[Tokenizer] = None + + # Dataset configs + train_ds: DatasetConfig = MISSING + validation_ds: DatasetConfig = MISSING + test_ds: DatasetConfig = MISSING + + # Optimizer / Scheduler config + optim: OptimConfig = MISSING + + # Model component configs + preprocessor: AudioToMelSpectrogramPreprocessorConfig = MISSING + spec_augment: Optional[SpectrogramAugmentationConfig] = MISSING + encoder: Any = MISSING + decoder: Any = MISSING + + +@dataclass +class EncDecCTCModelConfig(model_cfg.ModelPTConfig): + model: EncDecCTCConfig = EncDecCTCConfig() diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/configs/diffusion_config.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/configs/diffusion_config.py new file mode 100644 index 0000000000000000000000000000000000000000..1d92ee07025fedce53bdeaf2dfd462c29f9a7d4e --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/configs/diffusion_config.py @@ -0,0 +1,108 @@ +from typing import List, Optional + +from dataclasses import dataclass + + +@dataclass +class FrequencyEmbedding: + num_freqs: int + max_freq_log2: float + min_freq_log2: float = 0. + log_sampling: bool = True + include_input: bool = True + + +@dataclass +class PreBlock: + input_dim: int + ffn_hidden_dim: int + model_dim: int + out_dim: int + time_embed_dim: int + num_ffn_layers: int + embed_args: FrequencyEmbedding + + +@dataclass +class PostBlock: + input_dim: int + ffn_hidden_dim: int + model_dim: int + time_embed_dim: int + num_ffn_layers: int + context_dim: int + context_group_dim: int + + +@dataclass +class SeqUNet: + in_channels: int + model_channels: int + out_channels: int + num_res_blocks: int + attention_resolutions: List[int] + dropout: float = 0.0 + channel_mult: List[int] = (1, 2, 4, 8) + conv_resample: bool = True + use_checkpoint: bool = False + use_fp16: bool = False + num_heads: int = -1 + num_head_channels: int = -1 + num_heads_upsample: int = -1 + use_scale_shift_norm: bool = False + resblock_updown: bool = False + use_new_attention_order: bool = False + use_spatial_transformer: bool = False # custom transformer support + transformer_depth: int = 1 # custom transformer support + context_dim: Optional[int] = None # custom transformer support + legacy: bool = True + norm_type: str = 'gn' + attention_norm_type: str = 'gn' + gn_groups: int = 32 + ff_mult: int = 4 + gated_ff: bool = True + trfm_proj_input: bool = True + timestep_transform: Optional[str] = None + pre_block_config: Optional[PreBlock] = None + post_block_config: Optional[PostBlock] = None + condition_mode: Optional[str] = None + + +@dataclass +class WaveGrad: + signal_size: int + cond_size: int + hidden_size: int + + +@dataclass +class Diffusion: + conditioning_key: Optional[str] = None + condition_dropout: float = 0.0 + upsampling_mode: Optional[str] = None + scale_factor: float = 1.0 + scale_by_std: bool = False + beta_schedule: str = "linear" + linear_start: float = 1e-4 + linear_end: float = 2e-2 + loss_type: str = "l2" + parameterization: str = "eps" + + +@dataclass +class KarrasDiffusion: + sigma_data: float = 1. + sigma_min: float = 1e-2 + sigma_max: float = 80.0 + sigma_dist_type: str = 'lognormal' + sigma_dist_mean: float = -1.2 + sigma_dist_std: float = 1.2 + parameterization: str = 'hybrid' + loss_type: str = "l2" + + +@dataclass +class Sampler: + type: str = 'plms' + steps: int = 10 + guidance_scale: float = 1.0 diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/configs/flap_config.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/configs/flap_config.py new file mode 100644 index 0000000000000000000000000000000000000000..a7006a14560ee7a80db761b8c1a3c3530dec4cb6 --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/configs/flap_config.py @@ -0,0 +1,183 @@ +from dataclasses import dataclass +from typing import List, Optional, Any + +from omegaconf import MISSING + +import nemo +from nemo.collections.asr.models.configs.common_config import TextEmbed, ConvBlock +from nemo.collections.asr.models.configs.diffusion_config import SeqUNet, Diffusion, Sampler +from nemo.collections.asr.models.wav2vec.wav2vec_config import Wav2VecMaskingConfig +from nemo.collections.asr.modules.audio_preprocessing import ( + AudioToMelSpectrogramPreprocessorConfig, + SpectrogramAugmentationConfig, +) +from nemo.core.config import modelPT as model_cfg + +from .common_config import * +from .transducer_config import * +from .conv_transformer_config import * + + +@dataclass +class Text2TextDatasetConfig(nemo.core.classes.dataset.DatasetConfig): + manifest_dir: str = MISSING + manifest_filepath: str = MISSING + + tokenizer: str = MISSING + + parse_online: bool = False + mask_id: int = MISSING + mask_prob: float = MISSING + word_mask: bool = False + space_id: int = MISSING + space_num_min: int = MISSING + space_num_max: int = MISSING + in_word_space_num_min: int = 0 + in_word_space_num_max: int = 0 + replace_prob: float = 0.0 + replace_ids: Optional[List[int]] = None + min_text_len: int = 1 + max_text_len: int = 10240 + max_crop_words: int = 0 + max_crop_chars: int = 0 + file_format: str = 'txt' + encoding: str = 'utf-8' + + +@dataclass +class UpsamplingConfig: + variance: float = 1.0 + +@dataclass +class TextConvTransformerEncoder: + _target_: str = 'nemo.collections.asr.modules.TextConvTransformerEncoder' + text_embed: TextEmbed = MISSING + encoder: ConvTransformerBlock = MISSING + use_conv_mask: bool = MISSING + use_tf_pad: bool = True + ln_eps: float = 1e-5 + init_mode: str = 'xavier_uniform' + bias_init_mode: str = 'zero' + embedding_init_mode: str = 'xavier_uniform' + pad_to: int = 0 + pad_value: float = 0.0 + upsampling_args: Optional[UpsamplingConfig] = None + + +@dataclass +class LatentAutoencoder: + encoder: ConvBlock = MISSING + decoder: ConvBlock = MISSING + loss_type: str = 'l2' + + +@dataclass +class TextUpsampling: + min_ratio: float = MISSING + max_ratio: float = MISSING + max_len: int = 0 + + +@dataclass +class DiffusionTextProjector: + dpm_type: str = 'ddpm' + diffusion: Any = MISSING + conditioning_key: Optional[str] = None + condition_dropout: float = 0.0 + latent_denoiser_type: str = 'unet' + latent_denoiser: Any = MISSING + text_encoder: TextConvTransformerEncoder = MISSING + sampler: Sampler = Sampler() + text_upsampling: Optional[TextUpsampling] = None + latent_autoencoder: Optional[LatentAutoencoder] = None + + +@dataclass +class FLAPTextProjectorTransducerModel(model_cfg.ModelConfig): + labels: List[str] = MISSING + tokenizer: Optional[Tokenizer] = None + input_tokenizer: Optional[Tokenizer] = None + + train_ds: Text2TextDatasetConfig = MISSING + validation_ds: Text2TextDatasetConfig = MISSING + test_ds: Text2TextDatasetConfig = MISSING + + expected_gpu_num: int = MISSING + optim: Optional[OptimConfig] = MISSING + + text_projector: TextConvTransformerEncoder = MISSING + + pretrained_chkpt: str = MISSING + pretrained_chkpt_converter: str = MISSING + + speech_pre_encoder: ConvTransformerEncoder = MISSING + speech_backbone_encoder: ConvTransformerEncoder = MISSING + decoder: TransformerTDecoder = MISSING + joint: RNNTJoint = MISSING + model_defaults: ModelDefaults = MISSING + + decoding: Decoding = MISSING + + +@dataclass +class FLAPTextProjectorTransducerConfig(model_cfg.ModelPTConfig): + model: FLAPTextProjectorTransducerModel = FLAPTextProjectorTransducerModel() + + +@dataclass +class FLAPLatentMaskConfig: + mask_emb_dim: int = MISSING + text_masking: Optional[Wav2VecMaskingConfig] = None + speech_masking: Optional[Wav2VecMaskingConfig] = None + + +@dataclass +class FLAPMultiTaskTransducerModel(model_cfg.ModelConfig): + labels: List[str] = MISSING + tokenizer: Optional[Tokenizer] = None + input_tokenizer: Optional[Tokenizer] = None + + speech2text_train_ds: DatasetConfig = MISSING + speech2text_validation_ds: DatasetConfig = MISSING + speech2text_test_ds: DatasetConfig = MISSING + + text2text_train_ds: Text2TextDatasetConfig = MISSING + text2text_validation_ds: Text2TextDatasetConfig = MISSING + text2text_test_ds: Text2TextDatasetConfig = MISSING + + pretrained_chkpt: Optional[str] = None + + expected_gpu_num: int = MISSING + optim: Optional[OptimConfig] = MISSING + + preprocessor: AudioToMelSpectrogramPreprocessorConfig = MISSING + spec_augment: Optional[SpectrogramAugmentationConfig] = MISSING + latent_masking: Optional[FLAPLatentMaskConfig] = None + + freeze_speech_pre_encoder: bool = True + freeze_text_projector: bool = True + + freeze_speech_backbone_transducer: bool = False + freeze_auxiliary_decoder: bool = False + + text_projector: Any = MISSING + speech_pre_encoder: ConvTransformerEncoder = MISSING + speech_backbone_encoder: ConvTransformerEncoder = MISSING + decoder: TransformerTDecoder = MISSING + joint: RNNTJoint = MISSING + model_defaults: ModelDefaults = MISSING + + auxiliary_decoder: Any = None + freeze_text_auxiliary_decoder: bool = True + + speech2text_loss_weight: float = MISSING + text2text_loss_weight: float = MISSING + text2text_grad_accum_batches: int = 1 + auxiliary_loss_weight: float = 0.0 + + decoding: Decoding = MISSING + + +@dataclass +class FLAPMultiTaskTransducerConfig(model_cfg.ModelPTConfig): + model: FLAPMultiTaskTransducerModel = FLAPMultiTaskTransducerModel() diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/configs/spiral2_config.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/configs/spiral2_config.py new file mode 100644 index 0000000000000000000000000000000000000000..1d6bda2207dd9a042496152bf79783d0660b4c7a --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/configs/spiral2_config.py @@ -0,0 +1,98 @@ +from typing import Optional, List, Any + +from dataclasses import field, dataclass +from omegaconf import MISSING + +from nemo.collections.asr.models.configs.common_config import ConvBlock +from nemo.collections.asr.models.spec2vec.spec2vec_config import FeatureEncoderConfig, ProjectorConfig, \ + NoisePerturbConfig +from nemo.collections.asr.models.st2vec.st2vec_config import ShiftPerturbConfig +from nemo.collections.asr.models.wav2vec.wav2vec_config import LossConfig, Wav2VecTransformerConfig, \ + Wav2VecMaskingConfig, QuantizerConfig +from nemo.collections.asr.modules.audio_preprocessing import AudioToMelSpectrogramPreprocessorConfig +from nemo.core.config.modelPT import ModelConfig + + +@dataclass +class StyleFusionConfig: + # content_dim: int = MISSING + # style_dim: int = MISSING + content_upsampling_method: str = 'conv_proj' + content_upsampling_layers: List[Any] = None + fusion_method: str = 'sum' + + +@dataclass +class DecoderConfig: + content_dim: int = MISSING + style_dim: int = MISSING + style_fusion: StyleFusionConfig = MISSING + conv_block: ConvBlock = MISSING + content_quantizer: Optional[QuantizerConfig] = None + style_quantizer: Optional[QuantizerConfig] = None + use_tf_pad: bool = False + ln_eps: float = 1e-5 + + +@dataclass +class SPIRALConfig: + target_shifting: Optional[ShiftPerturbConfig] = None + target_compute_perturb: bool = False + + target_momentum: float = 0.99 + target_momentum_final: Optional[float] = None + target_momentum_steps: Optional[int] = None + target_momentum_type: Optional[str] = None + + projector: Optional[ProjectorConfig] = None + predictor: Optional[ProjectorConfig] = None + + n_negatives: int = field( + default=100, metadata={'help': 'Number of negatives to sample from the same audio sample'} + ) + cross_sample_negatives: int = field( + default=0, metadata={'help': 'Number of negatives to sample from any sample in the batch'} + ) + codebook_negatives: int = field(default=0, metadata={'help': 'Number of negative examples in codebook'}) + negatives_from_everywhere: bool = field( + default=False, metadata={'help': 'Sample negatives from everywhere, not just masked states'} + ) + negatives_from_noisy_features: bool = False + + +@dataclass +class ConvStyleEncoder(ConvBlock): + _target_: str = 'nemo.collections.asr.parts.convolution_layers.ConvBlock' + + +@dataclass +class SPIRAL2PretrainConfig(ModelConfig): + preprocessor: AudioToMelSpectrogramPreprocessorConfig = MISSING + + content_encoder: FeatureEncoderConfig = FeatureEncoderConfig() + # pretrained_encoder_path: Optional[str] = None + decoder: Optional[DecoderConfig] = None + freeze_feature_encoder: bool = False + noise_mix_ratio: Optional[float] = None + + content_masking: Optional[Wav2VecMaskingConfig] = None + + style_encoder: Optional[Any] = None + + style_masking: Optional[Wav2VecMaskingConfig] = None + + noise_perturb: Optional[NoisePerturbConfig] = None + + recon_loss_type: str = 'l2' + recon_time_mask_only: bool = False + + content_quant_ppl_loss_weight: float = 0.0 + style_quant_ppl_loss_weight: float = 0.0 + + spiral_config: Optional[SPIRALConfig] = None + spiral_loss_weight: float = 0.0 + spiral_loss_type: str = 'wav2vec' + spiral_logit_temp: float = field(default=0.1, metadata={'help': 'Temperature to divide logits by'}) + # loss: LossConfig = LossConfig() + + expected_gpu_num: int = 1 diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/configs/transducer_config.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/configs/transducer_config.py new file mode 100644 index 0000000000000000000000000000000000000000..d2092505c9affb0187fc23ae1465b4535f5a22dd --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/configs/transducer_config.py @@ -0,0 +1,93 @@ +from dataclasses import dataclass +from typing import Any, Optional + +from omegaconf import MISSING + +from nemo.collections.asr.models.configs.conv_transformer_config import RelTransformerBlock + +__all__ = ['TransformerTDecoder', 'JointNet', 'RNNTJoint', 'Greedy', 'Beam', 'Decoding', 'ModelDefaults'] + +@dataclass +class TransformerTDecoder: + _target_: str = 'nemo.collections.asr.modules.TransformerTDecoder' + vocab_size: int = MISSING + embed_size: int = MISSING + embed_dropout: float = MISSING + embed_proj_size: int = MISSING + norm_embed: bool = False + norm_embed_proj: bool = False + sos_idx: Optional[int] = MISSING + transformer_block: RelTransformerBlock = MISSING + blank_pos: str = MISSING + blank_as_pad: bool = True + mask_label_prob: float = 0.0 + mask_label_id: Optional[int] = None + ln_eps: float = 1e-5 + init_mode: str = 'xavier_uniform' + bias_init_mode: str = 'zero' + embedding_init_mode = 'xavier_uniform' + + +@dataclass +class JointNet: + joint_hidden: int = MISSING + activation: str = 'relu' + dropout: float = 0.0 + single_bias: bool = True + encoder_hidden: int = MISSING + pred_hidden: int = MISSING + + +@dataclass +class RNNTJoint: + _target_: str = 'nemo.collections.asr.modules.RNNTJoint' + + jointnet: JointNet = MISSING + + blank_pos: str = MISSING + + log_softmax: Any = None + + experimental_fuse_loss_wer: bool = False + fused_batch_size: int = 1 + + init_mode: str = 'xavier_uniform' + bias_init_mode: str = 'zero' + + +@dataclass +class Greedy: + max_symbols: int = MISSING + + +@dataclass +class Beam: + beam_size: int = MISSING + score_norm: bool = True + return_best_hypothesis: bool = True + + tsd_max_sym_exp_per_step: Optional[int] = 50 + alsd_max_target_len: Any = 1.0 + + nsc_max_timesteps_expansion: int = 1 + nsc_prefix_alpha: int = 1 + + beam_temperature: float = 1.0 + beam_combine_path: bool = True + beam_max_exp_step: int = 4 + beam_prune_exp: bool = True + beam_prune_exp_full: bool = True + beam_word_reward_ratio: float = 0.0 + + +@dataclass +class Decoding: + strategy: str = MISSING + greedy: Optional[Greedy] = None + beam: Optional[Beam] = None + + +@dataclass +class ModelDefaults: + enc_hidden: int = MISSING + pred_hidden: int = MISSING diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/ctc_bpe_models.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/ctc_bpe_models.py new file mode 100644 index 0000000000000000000000000000000000000000..84bb25256143e3fd4a33b422ee0cfa0b0da3b6ad --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/ctc_bpe_models.py @@ -0,0 +1,227 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import os +from typing import Dict, Optional + +import torch +from omegaconf import DictConfig, ListConfig, OmegaConf, open_dict + +from nemo.collections.asr.data import audio_to_text_dataset +from nemo.collections.asr.losses.ctc import CTCLoss +from nemo.collections.asr.metrics.wer_bpe import WERBPE +from nemo.collections.asr.models.ctc_models import EncDecCTCModel +from nemo.collections.asr.parts.mixins import ASRBPEMixin +from nemo.collections.asr.parts.perturb import process_augmentations +from nemo.core.classes.common import PretrainedModelInfo +from nemo.utils import logging, model_utils + +__all__ = ['EncDecCTCModelBPE'] + + +class EncDecCTCModelBPE(EncDecCTCModel, ASRBPEMixin): + """Encoder decoder CTC-based models with Byte Pair Encoding.""" + + @classmethod + def list_available_models(cls) -> Optional[PretrainedModelInfo]: + """ + This method returns a list of pre-trained model which can be instantiated directly from NVIDIA's NGC cloud. + + Returns: + List of available pre-trained models. + """ + result = [] + model = PretrainedModelInfo( + pretrained_model_name="ContextNet-192-WPE-1024-8x-Stride", + location="https://api.ngc.nvidia.com/v2/models/nvidia/nemospeechmodels/versions/1.0.0a5/files/ContextNet-192-WPE-1024-8x-Stride.nemo", + description="ContextNet initial implementation with CTC loss model trained on the Librispeech corpus and achieves a WER of 10.09% on test-other and 10.11% on dev-other.", + ) + result.append(model) + return result + + def __init__(self, cfg: DictConfig, trainer=None): + # Convert to Hydra 1.0 compatible DictConfig + cfg = model_utils.convert_model_config_to_dict_config(cfg) + cfg = model_utils.maybe_update_config_version(cfg) + + if 'tokenizer' not in cfg: + raise ValueError("`cfg` must have `tokenizer` config to create a tokenizer !") + + # Setup the tokenizer + self._setup_tokenizer(cfg.tokenizer) + + # Initialize a dummy vocabulary + vocabulary = self.tokenizer.tokenizer.get_vocab() + + # Set the new vocabulary + with open_dict(cfg): + cfg.decoder.vocabulary = ListConfig(list(vocabulary.values())) + + super().__init__(cfg=cfg, trainer=trainer) + + # Setup metric objects + self._wer = WERBPE( + tokenizer=self.tokenizer, + blank_id=self.decoder.blank_idx, + batch_dim_index=0, + use_cer=self._cfg.get('use_cer', False), + ctc_decode=True, + dist_sync_on_step=True, + log_prediction=self._cfg.get("log_prediction", False), + ) + + def _setup_dataloader_from_config(self, config: Optional[Dict]): + if 'augmentor' in config: + augmentor = process_augmentations(config['augmentor']) + else: + augmentor = None + + shuffle = config['shuffle'] + + # Instantiate tarred dataset loader or normal dataset loader + if config.get('is_tarred', False): + if ('tarred_audio_filepaths' in config and config['tarred_audio_filepaths'] is None) or ( + 'manifest_filepath' in config and config['manifest_filepath'] is None + ): + logging.warning( + "Could not load dataset as `manifest_filepath` was None or " + f"`tarred_audio_filepaths` is None. Provided config : {config}" + ) + return None + + shuffle_n = config.get('shuffle_n', 4 * config['batch_size']) if shuffle else 0 + dataset = audio_to_text_dataset.get_tarred_bpe_dataset( + config=config, + tokenizer=self.tokenizer, + shuffle_n=shuffle_n, + global_rank=self.global_rank, + world_size=self.world_size, + augmentor=augmentor, + ) + shuffle = False + else: + if 'manifest_filepath' in config and config['manifest_filepath'] is None: + logging.warning(f"Could not load dataset as `manifest_filepath` was None. Provided config : {config}") + return None + + dataset = audio_to_text_dataset.get_bpe_dataset( + config=config, tokenizer=self.tokenizer, augmentor=augmentor + ) + + return torch.utils.data.DataLoader( + dataset=dataset, + batch_size=config['batch_size'], + collate_fn=dataset.collate_fn, + drop_last=config.get('drop_last', False), + shuffle=shuffle, + num_workers=config.get('num_workers', 0), + pin_memory=config.get('pin_memory', False), + ) + + def _setup_transcribe_dataloader(self, config: Dict) -> 'torch.utils.data.DataLoader': + """ + Setup function for a temporary data loader which wraps the provided audio file. + + Args: + config: A python dictionary which contains the following keys: + paths2audio_files: (a list) of paths to audio files. The files should be relatively short fragments. \ + Recommended length per file is between 5 and 25 seconds. + batch_size: (int) batch size to use during inference. \ + Bigger will result in better throughput performance but would use more memory. + temp_dir: (str) A temporary directory where the audio manifest is temporarily + stored. + + Returns: + A pytorch DataLoader for the given audio file(s). + """ + dl_config = { + 'manifest_filepath': os.path.join(config['temp_dir'], 'manifest.json'), + 'sample_rate': self.preprocessor._sample_rate, + 'batch_size': min(config['batch_size'], len(config['paths2audio_files'])), + 'shuffle': False, + } + + temporary_datalayer = self._setup_dataloader_from_config(config=DictConfig(dl_config)) + return temporary_datalayer + + def change_vocabulary(self, new_tokenizer_dir: str, new_tokenizer_type: str): + """ + Changes vocabulary of the tokenizer used during CTC decoding process. + Use this method when fine-tuning on from pre-trained model. + This method changes only decoder and leaves encoder and pre-processing modules unchanged. For example, you would + use it if you want to use pretrained encoder when fine-tuning on a data in another language, or when you'd need + model to learn capitalization, punctuation and/or special characters. + + Args: + new_tokenizer_dir: Path to the new tokenizer directory. + new_tokenizer_type: Either `bpe` or `wpe`. `bpe` is used for SentencePiece tokenizers, + whereas `wpe` is used for `BertTokenizer`. + + Returns: None + + """ + if not os.path.isdir(new_tokenizer_dir): + raise NotADirectoryError( + f'New tokenizer dir must be non-empty path to a directory. But I got: {new_tokenizer_dir}' + ) + + if new_tokenizer_type.lower() not in ('bpe', 'wpe'): + raise ValueError(f'New tokenizer type must be either `bpe` or `wpe`') + + tokenizer_cfg = OmegaConf.create({'dir': new_tokenizer_dir, 'type': new_tokenizer_type}) + + # Setup the tokenizer + self._setup_tokenizer(tokenizer_cfg) + + # Initialize a dummy vocabulary + vocabulary = self.tokenizer.tokenizer.get_vocab() + + # Set the new vocabulary + decoder_config = copy.deepcopy(self.decoder.to_config_dict()) + decoder_config.vocabulary = ListConfig(list(vocabulary.values())) + + decoder_num_classes = decoder_config['num_classes'] + + # Override number of classes if placeholder provided + logging.info( + "\nReplacing old number of classes ({}) with new number of classes - {}".format( + decoder_num_classes, len(vocabulary) + ) + ) + + decoder_config['num_classes'] = len(vocabulary) + + del self.decoder + self.decoder = EncDecCTCModelBPE.from_config_dict(decoder_config) + del self.loss + self.loss = CTCLoss( + num_classes=self.decoder.num_classes_with_blank - 1, + zero_infinity=True, + reduction=self._cfg.get("ctc_reduction", "mean_batch"), + ) + self._wer = WERBPE( + tokenizer=self.tokenizer, + batch_dim_index=0, + use_cer=self._cfg.get('use_cer', False), + ctc_decode=True, + log_prediction=self._cfg.get("log_prediction", False), + ) + + # Update config + OmegaConf.set_struct(self._cfg.decoder, False) + self._cfg.decoder = decoder_config + OmegaConf.set_struct(self._cfg.decoder, True) + + logging.info(f"Changed tokenizer to {self.decoder.vocabulary} vocabulary.") diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/ctc_models.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/ctc_models.py new file mode 100644 index 0000000000000000000000000000000000000000..6e41cef3a771a9255ff0c75dfb60927ea828a6c8 --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/ctc_models.py @@ -0,0 +1,643 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import copy +import itertools +import json +import os +import tempfile +from math import ceil +from typing import Dict, List, Optional, Union + +import torch +from omegaconf import DictConfig, OmegaConf, open_dict +from pytorch_lightning import Trainer +from tqdm.auto import tqdm + +from nemo.collections.asr.data import audio_to_text_dataset +from nemo.collections.asr.data.audio_to_text_dali import DALIOutputs +from nemo.collections.asr.losses.ctc import CTCLoss +from nemo.collections.asr.metrics.wer import WER +from nemo.collections.asr.models.asr_model import ASRModel, ExportableEncDecModel +from nemo.collections.asr.parts.perturb import process_augmentations +from nemo.core.classes.common import PretrainedModelInfo, typecheck +from nemo.core.neural_types import AudioSignal, LabelsType, LengthsType, LogprobsType, NeuralType, SpectrogramType +from nemo.utils import logging + +__all__ = ['EncDecCTCModel'] + + +class EncDecCTCModel(ASRModel, ExportableEncDecModel): + """Base class for encoder decoder CTC-based models.""" + + @classmethod + def list_available_models(cls) -> Optional[PretrainedModelInfo]: + """ + This method returns a list of pre-trained model which can be instantiated directly from NVIDIA's NGC cloud. + + Returns: + List of available pre-trained models. + """ + results = [] + + model = PretrainedModelInfo( + pretrained_model_name="QuartzNet15x5Base-En", + description="QuartzNet15x5 model trained on six datasets: LibriSpeech, Mozilla Common Voice (validated clips from en_1488h_2019-12-10), WSJ, Fisher, Switchboard, and NSC Singapore English. It was trained with Apex/Amp optimization level O1 for 600 epochs. The model achieves a WER of 3.79% on LibriSpeech dev-clean, and a WER of 10.05% on dev-other. Please visit https://ngc.nvidia.com/catalog/models/nvidia:nemospeechmodels for further details.", + location="https://api.ngc.nvidia.com/v2/models/nvidia/nemospeechmodels/versions/1.0.0a5/files/QuartzNet15x5Base-En.nemo", + ) + results.append(model) + + model = PretrainedModelInfo( + pretrained_model_name="stt_en_quartznet15x5", + description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_quartznet15x5", + location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_en_quartznet15x5/versions/1.0.0rc1/files/stt_en_quartznet15x5.nemo", + ) + results.append(model) + + model = PretrainedModelInfo( + pretrained_model_name="stt_zh_quartznet15x5", + description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_zh_quartznet15x5", + location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_zh_quartznet15x5/versions/1.0.0rc1/files/stt_zh_quartznet15x5.nemo", + ) + results.append(model) + + model = PretrainedModelInfo( + pretrained_model_name="stt_en_jasper10x5dr", + description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_jasper10x5dr", + location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_en_jasper10x5dr/versions/1.0.0rc1/files/stt_en_jasper10x5dr.nemo", + ) + results.append(model) + + model = PretrainedModelInfo( + pretrained_model_name="stt_ca_quartznet15x5", + description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_ca_quartznet15x5", + location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_ca_quartznet15x5/versions/1.0.0rc1/files/stt_ca_quartznet15x5.nemo", + ) + results.append(model) + + model = PretrainedModelInfo( + pretrained_model_name="stt_it_quartznet15x5", + description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_it_quartznet15x5", + location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_it_quartznet15x5/versions/1.0.0rc1/files/stt_it_quartznet15x5.nemo", + ) + results.append(model) + + model = PretrainedModelInfo( + pretrained_model_name="stt_fr_quartznet15x5", + description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_fr_quartznet15x5", + location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_fr_quartznet15x5/versions/1.0.0rc1/files/stt_fr_quartznet15x5.nemo", + ) + results.append(model) + + model = PretrainedModelInfo( + pretrained_model_name="stt_es_quartznet15x5", + description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_es_quartznet15x5", + location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_es_quartznet15x5/versions/1.0.0rc1/files/stt_es_quartznet15x5.nemo", + ) + results.append(model) + + model = PretrainedModelInfo( + pretrained_model_name="stt_de_quartznet15x5", + description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_de_quartznet15x5", + location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_de_quartznet15x5/versions/1.0.0rc1/files/stt_de_quartznet15x5.nemo", + ) + results.append(model) + + model = PretrainedModelInfo( + pretrained_model_name="stt_pl_quartznet15x5", + description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_pl_quartznet15x5", + location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_pl_quartznet15x5/versions/1.0.0rc1/files/stt_pl_quartznet15x5.nemo", + ) + results.append(model) + + model = PretrainedModelInfo( + pretrained_model_name="stt_ru_quartznet15x5", + description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_ru_quartznet15x5", + location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_ru_quartznet15x5/versions/1.0.0rc1/files/stt_ru_quartznet15x5.nemo", + ) + results.append(model) + + model = PretrainedModelInfo( + pretrained_model_name="stt_zh_citrinet_512", + description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_zh_citrinet_512", + location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_zh_citrinet_512/versions/1.0.0rc1/files/stt_zh_citrinet_512.nemo", + ) + results.append(model) + + return results + + def __init__(self, cfg: DictConfig, trainer: Trainer = None): + # Get global rank and total number of GPU workers for IterableDataset partitioning, if applicable + # Global_rank and local_rank is set by LightningModule in Lightning 1.2.0 + self.global_rank = 0 + self.world_size = 1 + self.local_rank = 0 + if trainer is not None: + self.global_rank = (trainer.node_rank * trainer.num_gpus) + trainer.local_rank + self.world_size = trainer.num_nodes * trainer.num_gpus + self.local_rank = trainer.local_rank + + super().__init__(cfg=cfg, trainer=trainer) + self.preprocessor = EncDecCTCModel.from_config_dict(self._cfg.preprocessor) + self.encoder = EncDecCTCModel.from_config_dict(self._cfg.encoder) + + with open_dict(self._cfg): + if "feat_in" not in self._cfg.decoder or ( + not self._cfg.decoder.feat_in and hasattr(self.encoder, '_feat_out') + ): + self._cfg.decoder.feat_in = self.encoder._feat_out + if "feat_in" not in self._cfg.decoder or not self._cfg.decoder.feat_in: + raise ValueError("param feat_in of the decoder's config is not set!") + + self.decoder = EncDecCTCModel.from_config_dict(self._cfg.decoder) + + self.loss = CTCLoss( + blank_id=self.decoder.blank_idx, + zero_infinity=True, + reduction=self._cfg.get("ctc_reduction", "mean_batch"), + ) + + if hasattr(self._cfg, 'spec_augment') and self._cfg.spec_augment is not None: + self.spec_augmentation = EncDecCTCModel.from_config_dict(self._cfg.spec_augment) + else: + self.spec_augmentation = None + + # Setup metric objects + self._wer = WER( + vocabulary=self.decoder.vocabulary, + blank_id=self.decoder.blank_idx, + batch_dim_index=0, + use_cer=self._cfg.get('use_cer', False), + ctc_decode=True, + dist_sync_on_step=True, + log_prediction=self._cfg.get("log_prediction", False), + ) + + @torch.no_grad() + def transcribe( + self, paths2audio_files: List[str], batch_size: int = 4, logprobs=False, return_hypotheses: bool = False + ) -> List[str]: + """ + Uses greedy decoding to transcribe audio files. Use this method for debugging and prototyping. + + Args: + paths2audio_files: (a list) of paths to audio files. \ + Recommended length per file is between 5 and 25 seconds. \ + But it is possible to pass a few hours long file if enough GPU memory is available. + batch_size: (int) batch size to use during inference. + Bigger will result in better throughput performance but would use more memory. + logprobs: (bool) pass True to get log probabilities instead of transcripts. + return_hypotheses: (bool) Either return hypotheses or text + With hypotheses can do some postprocessing like getting timestamp or rescoring + + Returns: + A list of transcriptions (or raw log probabilities if logprobs is True) in the same order as paths2audio_files + """ + if paths2audio_files is None or len(paths2audio_files) == 0: + return {} + + if return_hypotheses and logprobs: + raise ValueError( + "Either `return_hypotheses` or `logprobs` can be True at any given time." + "Returned hypotheses will contain the logprobs." + ) + + # We will store transcriptions here + hypotheses = [] + # Model's mode and device + mode = self.training + device = next(self.parameters()).device + dither_value = self.preprocessor.featurizer.dither + pad_to_value = self.preprocessor.featurizer.pad_to + + try: + self.preprocessor.featurizer.dither = 0.0 + self.preprocessor.featurizer.pad_to = 0 + # Switch model to evaluation mode + self.eval() + # Freeze the encoder and decoder modules + self.encoder.freeze() + self.decoder.freeze() + logging_level = logging.get_verbosity() + logging.set_verbosity(logging.WARNING) + # Work in tmp directory - will store manifest file there + with tempfile.TemporaryDirectory() as tmpdir: + with open(os.path.join(tmpdir, 'manifest.json'), 'w') as fp: + for audio_file in paths2audio_files: + entry = {'audio_filepath': audio_file, 'duration': 100000, 'text': 'nothing'} + fp.write(json.dumps(entry) + '\n') + + config = {'paths2audio_files': paths2audio_files, 'batch_size': batch_size, 'temp_dir': tmpdir} + + temporary_datalayer = self._setup_transcribe_dataloader(config) + for test_batch in tqdm(temporary_datalayer, desc="Transcribing"): + logits, logits_len, greedy_predictions = self.forward( + input_signal=test_batch[0].to(device), input_signal_length=test_batch[1].to(device) + ) + if logprobs: + # dump log probs per file + for idx in range(logits.shape[0]): + hypotheses.append(logits[idx][: logits_len[idx]]) + else: + current_hypotheses = self._wer.ctc_decoder_predictions_tensor( + greedy_predictions, predictions_len=logits_len, return_hypotheses=return_hypotheses, + ) + + if return_hypotheses: + # dump log probs per file + for idx in range(logits.shape[0]): + current_hypotheses[idx].y_sequence = logits[idx][: logits_len[idx]] + + hypotheses += current_hypotheses + + del greedy_predictions + del logits + del test_batch + finally: + # set mode back to its original value + self.train(mode=mode) + self.preprocessor.featurizer.dither = dither_value + self.preprocessor.featurizer.pad_to = pad_to_value + if mode is True: + self.encoder.unfreeze() + self.decoder.unfreeze() + logging.set_verbosity(logging_level) + return hypotheses + + def change_vocabulary(self, new_vocabulary: List[str]): + """ + Changes vocabulary used during CTC decoding process. Use this method when fine-tuning on from pre-trained model. + This method changes only decoder and leaves encoder and pre-processing modules unchanged. For example, you would + use it if you want to use pretrained encoder when fine-tuning on a data in another language, or when you'd need + model to learn capitalization, punctuation and/or special characters. + + If new_vocabulary == self.decoder.vocabulary then nothing will be changed. + + Args: + + new_vocabulary: list with new vocabulary. Must contain at least 2 elements. Typically, \ + this is target alphabet. + + Returns: None + + """ + if self.decoder.vocabulary == new_vocabulary: + logging.warning(f"Old {self.decoder.vocabulary} and new {new_vocabulary} match. Not changing anything.") + else: + if new_vocabulary is None or len(new_vocabulary) == 0: + raise ValueError(f'New vocabulary must be non-empty list of chars. But I got: {new_vocabulary}') + decoder_config = self.decoder.to_config_dict() + new_decoder_config = copy.deepcopy(decoder_config) + new_decoder_config['vocabulary'] = new_vocabulary + new_decoder_config['num_classes'] = len(new_vocabulary) + + del self.decoder + self.decoder = EncDecCTCModel.from_config_dict(new_decoder_config) + del self.loss + self.loss = CTCLoss( + num_classes=self.decoder.num_classes_with_blank - 1, + zero_infinity=True, + reduction=self._cfg.get("ctc_reduction", "mean_batch"), + ) + self._wer = WER( + vocabulary=self.decoder.vocabulary, + batch_dim_index=0, + use_cer=self._cfg.get('use_cer', False), + ctc_decode=True, + dist_sync_on_step=True, + log_prediction=self._cfg.get("log_prediction", False), + ) + + # Update config + OmegaConf.set_struct(self._cfg.decoder, False) + self._cfg.decoder = new_decoder_config + OmegaConf.set_struct(self._cfg.decoder, True) + + logging.info(f"Changed decoder to output to {self.decoder.vocabulary} vocabulary.") + + def _setup_dataloader_from_config(self, config: Optional[Dict]): + if 'augmentor' in config: + augmentor = process_augmentations(config['augmentor']) + else: + augmentor = None + + shuffle = config['shuffle'] + device = 'gpu' if torch.cuda.is_available() else 'cpu' + if config.get('use_dali', False): + device_id = self.local_rank if device == 'gpu' else None + dataset = audio_to_text_dataset.get_dali_char_dataset( + config=config, + shuffle=shuffle, + device_id=device_id, + global_rank=self.global_rank, + world_size=self.world_size, + preprocessor_cfg=self._cfg.preprocessor, + ) + return dataset + + # Instantiate tarred dataset loader or normal dataset loader + if config.get('is_tarred', False): + if ('tarred_audio_filepaths' in config and config['tarred_audio_filepaths'] is None) or ( + 'manifest_filepath' in config and config['manifest_filepath'] is None + ): + logging.warning( + "Could not load dataset as `manifest_filepath` was None or " + f"`tarred_audio_filepaths` is None. Provided config : {config}" + ) + return None + + shuffle_n = config.get('shuffle_n', 4 * config['batch_size']) if shuffle else 0 + dataset = audio_to_text_dataset.get_tarred_char_dataset( + config=config, + shuffle_n=shuffle_n, + global_rank=self.global_rank, + world_size=self.world_size, + augmentor=augmentor, + ) + shuffle = False + else: + if 'manifest_filepath' in config and config['manifest_filepath'] is None: + logging.warning(f"Could not load dataset as `manifest_filepath` was None. Provided config : {config}") + return None + + dataset = audio_to_text_dataset.get_char_dataset(config=config, augmentor=augmentor) + + return torch.utils.data.DataLoader( + dataset=dataset, + batch_size=config['batch_size'], + collate_fn=dataset.collate_fn, + drop_last=config.get('drop_last', False), + shuffle=shuffle, + num_workers=config.get('num_workers', 0), + pin_memory=config.get('pin_memory', False), + ) + + def setup_training_data(self, train_data_config: Optional[Union[DictConfig, Dict]]): + """ + Sets up the training data loader via a Dict-like object. + + Args: + train_data_config: A config that contains the information regarding construction + of an ASR Training dataset. + + Supported Datasets: + - :class:`~nemo.collections.asr.data.audio_to_text.AudioToCharDataset` + - :class:`~nemo.collections.asr.data.audio_to_text.AudioToBPEDataset` + - :class:`~nemo.collections.asr.data.audio_to_text.TarredAudioToCharDataset` + - :class:`~nemo.collections.asr.data.audio_to_text.TarredAudioToBPEDataset` + - :class:`~nemo.collections.asr.data.audio_to_text_dali.AudioToCharDALIDataset` + """ + if 'shuffle' not in train_data_config: + train_data_config['shuffle'] = True + + # preserve config + self._update_dataset_config(dataset_name='train', config=train_data_config) + + self._train_dl = self._setup_dataloader_from_config(config=train_data_config) + + # Need to set this because if using an IterableDataset, the length of the dataloader is the total number + # of samples rather than the number of batches, and this messes up the tqdm progress bar. + # So we set the number of steps manually (to the correct number) to fix this. + if 'is_tarred' in train_data_config and train_data_config['is_tarred']: + # We also need to check if limit_train_batches is already set. + # If it's an int, we assume that the user has set it to something sane, i.e. <= # training batches, + # and don't change it. Otherwise, adjust batches accordingly if it's a float (including 1.0). + if isinstance(self._trainer.limit_train_batches, float): + self._trainer.limit_train_batches = int( + self._trainer.limit_train_batches + * ceil((len(self._train_dl.dataset) / self.world_size) / train_data_config['batch_size']) + ) + + def setup_validation_data(self, val_data_config: Optional[Union[DictConfig, Dict]]): + """ + Sets up the validation data loader via a Dict-like object. + + Args: + val_data_config: A config that contains the information regarding construction + of an ASR Training dataset. + + Supported Datasets: + - :class:`~nemo.collections.asr.data.audio_to_text.AudioToCharDataset` + - :class:`~nemo.collections.asr.data.audio_to_text.AudioToBPEDataset` + - :class:`~nemo.collections.asr.data.audio_to_text.TarredAudioToCharDataset` + - :class:`~nemo.collections.asr.data.audio_to_text.TarredAudioToBPEDataset` + - :class:`~nemo.collections.asr.data.audio_to_text_dali.AudioToCharDALIDataset` + """ + if 'shuffle' not in val_data_config: + val_data_config['shuffle'] = False + + # preserve config + self._update_dataset_config(dataset_name='validation', config=val_data_config) + + self._validation_dl = self._setup_dataloader_from_config(config=val_data_config) + + def setup_test_data(self, test_data_config: Optional[Union[DictConfig, Dict]]): + """ + Sets up the test data loader via a Dict-like object. + + Args: + test_data_config: A config that contains the information regarding construction + of an ASR Training dataset. + + Supported Datasets: + - :class:`~nemo.collections.asr.data.audio_to_text.AudioToCharDataset` + - :class:`~nemo.collections.asr.data.audio_to_text.AudioToBPEDataset` + - :class:`~nemo.collections.asr.data.audio_to_text.TarredAudioToCharDataset` + - :class:`~nemo.collections.asr.data.audio_to_text.TarredAudioToBPEDataset` + - :class:`~nemo.collections.asr.data.audio_to_text_dali.AudioToCharDALIDataset` + """ + if 'shuffle' not in test_data_config: + test_data_config['shuffle'] = False + + # preserve config + self._update_dataset_config(dataset_name='test', config=test_data_config) + + self._test_dl = self._setup_dataloader_from_config(config=test_data_config) + + @property + def input_types(self) -> Optional[Dict[str, NeuralType]]: + if hasattr(self.preprocessor, '_sample_rate'): + input_signal_eltype = AudioSignal(freq=self.preprocessor._sample_rate) + else: + input_signal_eltype = AudioSignal() + return { + "input_signal": NeuralType(('B', 'T'), input_signal_eltype, optional=True), + "input_signal_length": NeuralType(tuple('B'), LengthsType(), optional=True), + "processed_signal": NeuralType(('B', 'D', 'T'), SpectrogramType(), optional=True), + "processed_signal_length": NeuralType(tuple('B'), LengthsType(), optional=True), + } + + @property + def output_types(self) -> Optional[Dict[str, NeuralType]]: + return { + "outputs": NeuralType(('B', 'T', 'D'), LogprobsType()), + "encoded_lengths": NeuralType(tuple('B'), LengthsType()), + "greedy_predictions": NeuralType(('B', 'T'), LabelsType()), + } + + @typecheck() + def forward( + self, input_signal=None, input_signal_length=None, processed_signal=None, processed_signal_length=None + ): + """ + Forward pass of the model. + + Args: + input_signal: Tensor that represents a batch of raw audio signals, + of shape [B, T]. T here represents timesteps, with 1 second of audio represented as + `self.sample_rate` number of floating point values. + input_signal_length: Vector of length B, that contains the individual lengths of the audio + sequences. + processed_signal: Tensor that represents a batch of processed audio signals, + of shape (B, D, T) that has undergone processing via some DALI preprocessor. + processed_signal_length: Vector of length B, that contains the individual lengths of the + processed audio sequences. + + Returns: + A tuple of 3 elements - + 1) The log probabilities tensor of shape [B, T, D]. + 2) The lengths of the acoustic sequence after propagation through the encoder, of shape [B]. + 3) The greedy token predictions of the model of shape [B, T] (via argmax) + """ + has_input_signal = input_signal is not None and input_signal_length is not None + has_processed_signal = processed_signal is not None and processed_signal_length is not None + if (has_input_signal ^ has_processed_signal) == False: + raise ValueError( + f"{self} Arguments ``input_signal`` and ``input_signal_length`` are mutually exclusive " + " with ``processed_signal`` and ``processed_signal_len`` arguments." + ) + + if not has_processed_signal: + processed_signal, processed_signal_length = self.preprocessor( + input_signal=input_signal, length=input_signal_length, + ) + + if self.spec_augmentation is not None and self.training: + processed_signal = self.spec_augmentation(input_spec=processed_signal, length=processed_signal_length) + + encoded, encoded_len, _ = self.encoder(audio_signal=processed_signal, length=processed_signal_length) + logits, encoded_len = self.decoder(encoder_output=encoded, lens=encoded_len, log_prob=False) + log_probs = torch.nn.functional.log_softmax(logits, dim=-1) + + with torch.no_grad(): + greedy_predictions = log_probs.argmax(dim=-1, keepdim=False) + + return log_probs, encoded_len, greedy_predictions + + # PTL-specific methods + def training_step(self, batch, batch_nb): + signal, signal_len, transcript, transcript_len = batch + if isinstance(batch, DALIOutputs) and batch.has_processed_signal: + log_probs, encoded_len, predictions = self.forward( + processed_signal=signal, processed_signal_length=signal_len + ) + else: + log_probs, encoded_len, predictions = self.forward(input_signal=signal, input_signal_length=signal_len) + + loss_value = self.loss( + log_probs=log_probs, targets=transcript, input_lengths=encoded_len, target_lengths=transcript_len + ) + + tensorboard_logs = {'train_loss': loss_value, 'learning_rate': self._optimizer.param_groups[0]['lr']} + + if hasattr(self, '_trainer') and self._trainer is not None: + log_every_n_steps = self._trainer.log_every_n_steps + else: + log_every_n_steps = 1 + + # if (batch_nb + 1) % log_every_n_steps == 0: + # self._wer.update( + # predictions=predictions, + # targets=transcript, + # target_lengths=transcript_len, + # predictions_lengths=encoded_len, + # ) + # wer, _, _ = self._wer.compute() + # tensorboard_logs.update({'training_batch_wer': wer}) + + return {'loss': loss_value, 'log': tensorboard_logs} + + def validation_step(self, batch, batch_idx, dataloader_idx=0, decode_results=None): + signal, signal_len, transcript, transcript_len = batch + if isinstance(batch, DALIOutputs) and batch.has_processed_signal: + log_probs, encoded_len, predictions = self.forward( + processed_signal=signal, processed_signal_length=signal_len + ) + else: + log_probs, encoded_len, predictions = self.forward(input_signal=signal, input_signal_length=signal_len) + + loss_value = self.loss( + log_probs=log_probs, targets=transcript, input_lengths=encoded_len, target_lengths=transcript_len + ) + self._wer.update( + predictions=predictions, targets=transcript, target_lengths=transcript_len, decode_results=decode_results, + predictions_lengths=encoded_len, log_prediction=batch_idx < 3) + wer, wer_num, wer_denom = self._wer.compute() + return { + 'val_loss': loss_value, + 'val_wer_num': wer_num, + 'val_wer_denom': wer_denom, + 'val_wer': wer, + } + + def test_step(self, batch, batch_idx, dataloader_idx=0): + decode_results = {} + logs = self.validation_step(batch, batch_idx, dataloader_idx=dataloader_idx, decode_results=decode_results) + test_logs = { + 'test_loss': logs['val_loss'], + 'test_wer_num': logs['val_wer_num'], + 'test_wer_denom': logs['val_wer_denom'], + 'test_wer': logs['val_wer'], + 'test_references': decode_results['references'], + 'test_hypotheses': decode_results['hypotheses'], + } + return test_logs + + def multi_test_epoch_end(self, outputs, dataloader_idx: int = 0): + references = itertools.chain.from_iterable([x['test_references'] for x in outputs]) + hypotheses = itertools.chain.from_iterable([x['test_hypotheses'] for x in outputs]) + res = super().multi_test_epoch_end(outputs, dataloader_idx) + res['decode_results'] = (references, hypotheses) + return res + + def test_dataloader(self): + if self._test_dl is not None: + return self._test_dl + + def _setup_transcribe_dataloader(self, config: Dict) -> 'torch.utils.data.DataLoader': + """ + Setup function for a temporary data loader which wraps the provided audio file. + + Args: + config: A python dictionary which contains the following keys: + paths2audio_files: (a list) of paths to audio files. The files should be relatively short fragments. \ + Recommended length per file is between 5 and 25 seconds. + batch_size: (int) batch size to use during inference. \ + Bigger will result in better throughput performance but would use more memory. + temp_dir: (str) A temporary directory where the audio manifest is temporarily + stored. + + Returns: + A pytorch DataLoader for the given audio file(s). + """ + dl_config = { + 'manifest_filepath': os.path.join(config['temp_dir'], 'manifest.json'), + 'sample_rate': self.preprocessor._sample_rate, + 'labels': self.decoder.vocabulary, + 'batch_size': min(config['batch_size'], len(config['paths2audio_files'])), + 'trim_silence': True, + 'shuffle': False, + } + + temporary_datalayer = self._setup_dataloader_from_config(config=DictConfig(dl_config)) + return temporary_datalayer diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/flap_transducer_model.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/flap_transducer_model.py new file mode 100644 index 0000000000000000000000000000000000000000..4934082f866a825bd38864ae253594ab7bbc2126 --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/flap_transducer_model.py @@ -0,0 +1,67 @@ +import contextlib + +import torch + +from nemo.collections.asr.models.st2vec.st2vec_model import apply_mask, create_padding_mask + + +class FLAPTransducerModel(torch.nn.Module): + + def __init__(self, text_projector, speech_pre_encoder, speech_backbone_transducer, *, auxiliary_decoder=None, latent_masking_cfg=None, + use_dpm_text_projector=False): + super().__init__() + + self.use_dpm_text_projector = use_dpm_text_projector + self.text_projector = text_projector + self.speech_pre_encoder = speech_pre_encoder + self.speech_backbone_transducer = speech_backbone_transducer + self.auxiliary_decoder = auxiliary_decoder + if latent_masking_cfg: + self.mask_emb = torch.nn.Parameter(torch.FloatTensor(latent_masking_cfg.mask_emb_dim).uniform_()) + self.text_masking = latent_masking_cfg.text_masking + self.speech_masking = latent_masking_cfg.speech_masking + else: + self.mask_emb = None + self.text_masking = None + self.speech_masking = None + + def forward(self, *, text, text_len, speech, speech_len, transcript, transcript_len, mask=True, freeze_auxiliary_decoder=False): + if text is not None: + assert speech is None + if self.use_dpm_text_projector: + latent, latent_len = self.text_projector(text, text_len) + else: + latent, latent_len, _ = self.text_projector(text, text_len, match_output_len=False) + unmasked_latent = latent + + if self.text_masking and mask: + if self.auxiliary_decoder is not None: + unmasked_latent = latent.clone() + + latent = latent.transpose(1, 2) + latent_mask = create_padding_mask(latent_len, latent.shape[1]) + latent, _, _ = apply_mask(self.text_masking, latent, latent_mask, self.mask_emb) + latent = latent.transpose(1, 2) + else: + latent, latent_len, _ = self.speech_pre_encoder(audio_signal=speech, length=speech_len, match_output_len=False) + unmasked_latent = latent + + if self.speech_masking and mask: + if self.auxiliary_decoder is not None: + unmasked_latent = latent.clone() + + latent = latent.transpose(1, 2) + latent_mask = create_padding_mask(latent_len, latent.shape[1]) + latent, _, _ = apply_mask(self.speech_masking, latent, latent_mask, self.mask_emb) + latent = latent.transpose(1, 2) + + results = self.speech_backbone_transducer(latent, latent_len, transcript, transcript_len) + + extra = {} + if self.auxiliary_decoder is not None: + with torch.no_grad() if freeze_auxiliary_decoder else contextlib.suppress(): + aux_log_prob, aux_log_prob_len = self.auxiliary_decoder(encoder_output=unmasked_latent, lens=latent_len) + extra['aux_decoder_log_probs'] = aux_log_prob + extra['aux_decoder_log_probs_len'] = aux_log_prob_len + + return results + (extra,) diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/label_models.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/label_models.py new file mode 100644 index 0000000000000000000000000000000000000000..dab46f0ef438dac55f59022db4c5270d0f805b60 --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/label_models.py @@ -0,0 +1,422 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import json +import os +import pickle as pkl +from typing import Dict, List, Optional, Union + +import onnx +import torch +from omegaconf import DictConfig +from omegaconf.omegaconf import open_dict +from pytorch_lightning import Trainer + +from nemo.collections.asr.data.audio_to_label import AudioToSpeechLabelDataSet +from nemo.collections.asr.losses.angularloss import AngularSoftmaxLoss +from nemo.collections.asr.parts.features import WaveformFeaturizer +from nemo.collections.asr.parts.perturb import process_augmentations +from nemo.collections.common.losses import CrossEntropyLoss as CELoss +from nemo.collections.common.metrics import TopKClassificationAccuracy +from nemo.core.classes import ModelPT +from nemo.core.classes.common import PretrainedModelInfo, typecheck +from nemo.core.classes.exportable import Exportable +from nemo.core.neural_types import * +from nemo.utils import logging +from nemo.utils.export_utils import attach_onnx_to_onnx + +__all__ = ['EncDecSpeakerLabelModel', 'ExtractSpeakerEmbeddingsModel'] + + +class EncDecSpeakerLabelModel(ModelPT, Exportable): + """Encoder decoder class for speaker label models. + Model class creates training, validation methods for setting up data + performing model forward pass. + Expects config dict for + * preprocessor + * Jasper/Quartznet Encoder + * Speaker Decoder + """ + + @classmethod + def list_available_models(cls) -> List[PretrainedModelInfo]: + """ + This method returns a list of pre-trained model which can be instantiated directly from NVIDIA's NGC cloud. + + Returns: + List of available pre-trained models. + """ + result = [] + model = PretrainedModelInfo( + pretrained_model_name="SpeakerNet_recognition", + location="https://api.ngc.nvidia.com/v2/models/nvidia/nemospeechmodels/versions/1.0.0a5/files/SpeakerNet_recognition.nemo", + description="SpeakerNet_recognition model trained end-to-end for speaker recognition purposes with cross_entropy loss. It was trained on voxceleb 1, voxceleb 2 dev datasets and augmented with musan music and noise. Speaker Recognition model achieves 2.65% EER on voxceleb-O cleaned trial file", + ) + result.append(model) + + model = PretrainedModelInfo( + pretrained_model_name="SpeakerNet_verification", + location="https://api.ngc.nvidia.com/v2/models/nvidia/nemospeechmodels/versions/1.0.0a5/files/SpeakerNet_verification.nemo", + description="SpeakerNet_verification model trained end-to-end for speaker verification purposes with arcface angular softmax loss. It was trained on voxceleb 1, voxceleb 2 dev datasets and augmented with musan music and noise. Speaker Verification model achieves 2.12% EER on voxceleb-O cleaned trial file", + ) + result.append(model) + + return result + + def __init__(self, cfg: DictConfig, trainer: Trainer = None): + super().__init__(cfg=cfg, trainer=trainer) + self.preprocessor = EncDecSpeakerLabelModel.from_config_dict(cfg.preprocessor) + self.encoder = EncDecSpeakerLabelModel.from_config_dict(cfg.encoder) + self.decoder = EncDecSpeakerLabelModel.from_config_dict(cfg.decoder) + if 'angular' in cfg.decoder and cfg.decoder['angular']: + logging.info("Training with Angular Softmax Loss") + scale = cfg.loss.scale + margin = cfg.loss.margin + self.loss = AngularSoftmaxLoss(scale=scale, margin=margin) + else: + logging.info("Training with Softmax-CrossEntropy loss") + self.loss = CELoss() + + self._accuracy = TopKClassificationAccuracy(top_k=[1], dist_sync_on_step=True) + + def __setup_dataloader_from_config(self, config: Optional[Dict]): + if 'augmentor' in config: + augmentor = process_augmentations(config['augmentor']) + else: + augmentor = None + + featurizer = WaveformFeaturizer( + sample_rate=config['sample_rate'], int_values=config.get('int_values', False), augmentor=augmentor + ) + self.dataset = AudioToSpeechLabelDataSet( + manifest_filepath=config['manifest_filepath'], + labels=config['labels'], + featurizer=featurizer, + max_duration=config.get('max_duration', None), + min_duration=config.get('min_duration', None), + trim=config.get('trim_silence', True), + load_audio=config.get('load_audio', True), + time_length=config.get('time_length', 8), + ) + + return torch.utils.data.DataLoader( + dataset=self.dataset, + batch_size=config['batch_size'], + collate_fn=self.dataset.fixed_seq_collate_fn, + drop_last=config.get('drop_last', False), + shuffle=config['shuffle'], + num_workers=config.get('num_workers', 2), + pin_memory=config.get('pin_memory', False), + ) + + def setup_training_data(self, train_data_layer_config: Optional[Union[DictConfig, Dict]]): + if 'shuffle' not in train_data_layer_config: + train_data_layer_config['shuffle'] = True + self._train_dl = self.__setup_dataloader_from_config(config=train_data_layer_config) + + def setup_validation_data(self, val_data_layer_config: Optional[Union[DictConfig, Dict]]): + if 'shuffle' not in val_data_layer_config: + val_data_layer_config['shuffle'] = False + val_data_layer_config['labels'] = self.dataset.labels + self._validation_dl = self.__setup_dataloader_from_config(config=val_data_layer_config) + + def setup_test_data(self, test_data_layer_params: Optional[Union[DictConfig, Dict]]): + if 'shuffle' not in test_data_layer_params: + test_data_layer_params['shuffle'] = False + if hasattr(self, 'dataset'): + test_data_layer_params['labels'] = self.dataset.labels + self.embedding_dir = test_data_layer_params.get('embedding_dir', './') + self.test_manifest = test_data_layer_params.get('manifest_filepath', None) + self._test_dl = self.__setup_dataloader_from_config(config=test_data_layer_params) + + @property + def input_types(self) -> Optional[Dict[str, NeuralType]]: + if hasattr(self.preprocessor, '_sample_rate'): + audio_eltype = AudioSignal(freq=self.preprocessor._sample_rate) + else: + audio_eltype = AudioSignal() + return { + "input_signal": NeuralType(('B', 'T'), audio_eltype), + "input_signal_length": NeuralType(tuple('B'), LengthsType()), + } + + @property + def output_types(self) -> Optional[Dict[str, NeuralType]]: + return { + "logits": NeuralType(('B', 'D'), LogitsType()), + "embs": NeuralType(('B', 'D'), AcousticEncodedRepresentation()), + } + + @typecheck() + def forward(self, input_signal, input_signal_length): + processed_signal, processed_signal_len = self.preprocessor( + input_signal=input_signal, length=input_signal_length, + ) + + encoded, _ = self.encoder(audio_signal=processed_signal, length=processed_signal_len) + logits, embs = self.decoder(encoder_output=encoded) + return logits, embs + + # PTL-specific methods + def training_step(self, batch, batch_idx): + audio_signal, audio_signal_len, labels, _ = batch + logits, _ = self.forward(input_signal=audio_signal, input_signal_length=audio_signal_len) + loss = self.loss(logits=logits, labels=labels) + + self.log('loss', loss) + self.log('learning_rate', self._optimizer.param_groups[0]['lr']) + + self._accuracy(logits=logits, labels=labels) + top_k = self._accuracy.compute() + for i, top_i in enumerate(top_k): + self.log(f'training_batch_accuracy_top@{i}', top_i) + + return {'loss': loss} + + def validation_step(self, batch, batch_idx, dataloader_idx: int = 0): + audio_signal, audio_signal_len, labels, _ = batch + logits, _ = self.forward(input_signal=audio_signal, input_signal_length=audio_signal_len) + loss_value = self.loss(logits=logits, labels=labels) + acc_top_k = self._accuracy(logits=logits, labels=labels) + correct_counts, total_counts = self._accuracy.correct_counts_k, self._accuracy.total_counts_k + + return { + 'val_loss': loss_value, + 'val_correct_counts': correct_counts, + 'val_total_counts': total_counts, + 'val_acc_top_k': acc_top_k, + } + + def multi_validation_epoch_end(self, outputs, dataloader_idx: int = 0): + val_loss_mean = torch.stack([x['val_loss'] for x in outputs]).mean() + correct_counts = torch.stack([x['val_correct_counts'] for x in outputs]).sum(axis=0) + total_counts = torch.stack([x['val_total_counts'] for x in outputs]).sum(axis=0) + + self._accuracy.correct_counts_k = correct_counts + self._accuracy.total_counts_k = total_counts + topk_scores = self._accuracy.compute() + + logging.info("val_loss: {:.3f}".format(val_loss_mean)) + self.log('val_loss', val_loss_mean) + for top_k, score in zip(self._accuracy.top_k, topk_scores): + self.log('val_epoch_accuracy_top@{}'.format(top_k), score) + + return { + 'val_loss': val_loss_mean, + 'val_acc_top_k': topk_scores, + } + + def test_step(self, batch, batch_idx, dataloader_idx: int = 0): + audio_signal, audio_signal_len, labels, _ = batch + logits, _ = self.forward(input_signal=audio_signal, input_signal_length=audio_signal_len) + loss_value = self.loss(logits=logits, labels=labels) + acc_top_k = self._accuracy(logits=logits, labels=labels) + correct_counts, total_counts = self._accuracy.correct_counts_k, self._accuracy.total_counts_k + + return { + 'test_loss': loss_value, + 'test_correct_counts': correct_counts, + 'test_total_counts': total_counts, + 'test_acc_top_k': acc_top_k, + } + + def multi_test_epoch_end(self, outputs, dataloader_idx: int = 0): + test_loss_mean = torch.stack([x['test_loss'] for x in outputs]).mean() + correct_counts = torch.stack([x['test_correct_counts'] for x in outputs]).sum(axis=0) + total_counts = torch.stack([x['test_total_counts'] for x in outputs]).sum(axis=0) + + self._accuracy.correct_counts_k = correct_counts + self._accuracy.total_counts_k = total_counts + topk_scores = self._accuracy.compute() + + logging.info("test_loss: {:.3f}".format(test_loss_mean)) + self.log('test_loss', test_loss_mean) + for top_k, score in zip(self._accuracy.top_k, topk_scores): + self.log('test_epoch_accuracy_top@{}'.format(top_k), score) + + return { + 'test_loss': test_loss_mean, + 'test_acc_top_k': topk_scores, + } + + def setup_finetune_model(self, model_config: DictConfig): + """ + setup_finetune_model method sets up training data, validation data and test data with new + provided config, this checks for the previous labels set up during training from scratch, if None, + it sets up labels for provided finetune data from manifest files + + Args: + model_config: cfg which has train_ds, optional validation_ds, optional test_ds and + mandatory encoder and decoder model params + make sure you set num_classes correctly for finetune data + + Returns: None + + """ + if hasattr(self, 'dataset'): + scratch_labels = self.dataset.labels + else: + scratch_labels = None + + logging.info("Setting up data loaders with manifests provided from model_config") + + if 'train_ds' in model_config and model_config.train_ds is not None: + self.setup_training_data(model_config.train_ds) + else: + raise KeyError("train_ds is not found in model_config but you need it for fine tuning") + + if self.dataset.labels is None or len(self.dataset.labels) == 0: + raise ValueError(f'New labels must be non-empty list of labels. But I got: {self.dataset.labels}') + + if 'validation_ds' in model_config and model_config.validation_ds is not None: + self.setup_multiple_validation_data(model_config.validation_ds) + + if 'test_ds' in model_config and model_config.test_ds is not None: + self.setup_multiple_test_data(model_config.test_ds) + + if scratch_labels == self.dataset.labels: # checking for new finetune dataset labels + logging.warning( + "Trained dataset labels are same as finetune dataset labels -- continuing change of decoder parameters" + ) + elif scratch_labels is None: + logging.warning( + "Either you provided a dummy manifest file during training from scratch or you restored from a pretrained nemo file" + ) + + decoder_config = model_config.decoder + new_decoder_config = copy.deepcopy(decoder_config) + if new_decoder_config['num_classes'] != len(self.dataset.labels): + raise ValueError( + "number of classes provided {} is not same as number of different labels in finetuning data: {}".format( + new_decoder_config['num_classes'], len(self.dataset.labels) + ) + ) + + del self.decoder + self.decoder = EncDecSpeakerLabelModel.from_config_dict(new_decoder_config) + + with open_dict(self._cfg.decoder): + self._cfg.decoder = new_decoder_config + + logging.info(f"Changed decoder output to # {self.decoder._num_classes} classes.") + + def export( + self, + output: str, + input_example=None, + output_example=None, + verbose=False, + export_params=True, + do_constant_folding=True, + keep_initializers_as_inputs=False, + onnx_opset_version: int = 12, + try_script: bool = False, + set_eval: bool = True, + check_trace: bool = True, + use_dynamic_axes: bool = True, + ): + if input_example is not None or output_example is not None: + logging.warning( + "Passed input and output examples will be ignored and recomputed since" + " EncDecSpeakerModel consists of two separate models (encoder and decoder) with different" + " inputs and outputs." + ) + + qual_name = self.__module__ + '.' + self.__class__.__qualname__ + output1 = os.path.join(os.path.dirname(output), 'encoder_' + os.path.basename(output)) + output1_descr = qual_name + ' Encoder exported to ONNX' + encoder_onnx = self.encoder.export( + output1, + None, # computed by input_example() + None, + verbose, + export_params, + do_constant_folding, + keep_initializers_as_inputs, + onnx_opset_version, + try_script, + set_eval, + check_trace, + use_dynamic_axes, + ) + + output2 = os.path.join(os.path.dirname(output), 'decoder_' + os.path.basename(output)) + output2_descr = qual_name + ' Decoder exported to ONNX' + decoder_onnx = self.decoder.export( + output2, + None, # computed by input_example() + None, + verbose, + export_params, + do_constant_folding, + keep_initializers_as_inputs, + onnx_opset_version, + try_script, + set_eval, + check_trace, + use_dynamic_axes, + ) + + output_model = attach_onnx_to_onnx(encoder_onnx, decoder_onnx, "SL") + output_descr = qual_name + ' Encoder+Decoder exported to ONNX' + onnx.save(output_model, output) + return ([output, output1, output2], [output_descr, output1_descr, output2_descr]) + + +class ExtractSpeakerEmbeddingsModel(EncDecSpeakerLabelModel): + """ + This Model class facilitates extraction of speaker embeddings from a pretrained model. + Respective embedding file is saved in self.embedding dir passed through cfg + """ + + def __init__(self, cfg: DictConfig, trainer: Trainer = None): + super().__init__(cfg=cfg, trainer=trainer) + + def test_step(self, batch, batch_ix): + audio_signal, audio_signal_len, labels, slices = batch + _, embs = self.forward(input_signal=audio_signal, input_signal_length=audio_signal_len) + return {'embs': embs, 'labels': labels, 'slices': slices} + + def test_epoch_end(self, outputs): + embs = torch.cat([x['embs'] for x in outputs]) + slices = torch.cat([x['slices'] for x in outputs]) + emb_shape = embs.shape[-1] + embs = embs.view(-1, emb_shape).cpu().numpy() + out_embeddings = {} + start_idx = 0 + with open(self.test_manifest, 'r') as manifest: + for idx, line in enumerate(manifest.readlines()): + line = line.strip() + dic = json.loads(line) + structure = dic['audio_filepath'].split('/')[-3:] + uniq_name = '@'.join(structure) + if uniq_name in out_embeddings: + raise KeyError("Embeddings for label {} already present in emb dictionary".format(uniq_name)) + num_slices = slices[idx] + end_idx = start_idx + num_slices + out_embeddings[uniq_name] = embs[start_idx:end_idx].mean(axis=0) + start_idx = end_idx + + embedding_dir = os.path.join(self.embedding_dir, 'embeddings') + if not os.path.exists(embedding_dir): + os.mkdir(embedding_dir) + + prefix = self.test_manifest.split('/')[-1].split('.')[-2] + + name = os.path.join(embedding_dir, prefix) + pkl.dump(out_embeddings, open(name + '_embeddings.pkl', 'wb')) + logging.info("Saved embedding files to {}".format(embedding_dir)) + + return {} diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/rnnt_bpe_models.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/rnnt_bpe_models.py new file mode 100644 index 0000000000000000000000000000000000000000..7808120d22694ddd851cafb93bc043372ef0e5e5 --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/rnnt_bpe_models.py @@ -0,0 +1,267 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import os +from typing import Dict, Optional + +import torch +from omegaconf import DictConfig, ListConfig, OmegaConf, open_dict +from pytorch_lightning import Trainer + +from nemo.collections.asr.data import audio_to_text_dataset +from nemo.collections.asr.losses.rnnt import RNNTLoss +from nemo.collections.asr.metrics.rnnt_wer_bpe import RNNTBPEWER, RNNTBPEDecoding +from nemo.collections.asr.models.rnnt_models import EncDecRNNTModel +from nemo.collections.asr.parts.mixins import ASRBPEMixin +from nemo.collections.asr.parts.perturb import process_augmentations +from nemo.core.classes.common import PretrainedModelInfo +from nemo.utils import logging, model_utils + +try: + import warprnnt_pytorch as warprnnt + + WARP_RNNT_AVAILABLE = True +except (ImportError, ModuleNotFoundError): + WARP_RNNT_AVAILABLE = False + + +class EncDecRNNTBPEModel(EncDecRNNTModel, ASRBPEMixin): + """Base class for encoder decoder RNNT-based models with subword tokenization.""" + + @classmethod + def list_available_models(cls) -> Optional[PretrainedModelInfo]: + """ + This method returns a list of pre-trained model which can be instantiated directly from NVIDIA's NGC cloud. + + Returns: + List of available pre-trained models. + """ + result = [] + return result + + def __init__(self, cfg: DictConfig, trainer: Trainer = None): + # Required loss function + if not WARP_RNNT_AVAILABLE: + raise ImportError( + "Could not import `warprnnt_pytorch`.\n" + "Please visit https://github.com/HawkAaron/warp-transducer " + "and follow the steps in the readme to build and install the " + "pytorch bindings for RNNT Loss, or use the provided docker " + "container that supports RNN-T loss." + ) + + # Convert to Hydra 1.0 compatible DictConfig + cfg = model_utils.convert_model_config_to_dict_config(cfg) + cfg = model_utils.maybe_update_config_version(cfg) + + # Tokenizer is necessary for this model + if 'tokenizer' not in cfg: + raise ValueError("`cfg` must have `tokenizer` config to create a tokenizer !") + + if not isinstance(cfg, DictConfig): + cfg = OmegaConf.create(cfg) + + # Setup the tokenizer + self._setup_tokenizer(cfg.tokenizer) + + # Initialize a dummy vocabulary + vocabulary = self.tokenizer.tokenizer.get_vocab() + + # Set the new vocabulary + with open_dict(cfg): + cfg.labels = ListConfig(list(vocabulary)) + + with open_dict(cfg.decoder): + cfg.decoder.vocab_size = len(vocabulary) + + with open_dict(cfg.joint): + cfg.joint.num_classes = len(vocabulary) + cfg.joint.vocabulary = ListConfig(list(vocabulary)) + cfg.joint.jointnet.encoder_hidden = cfg.model_defaults.enc_hidden + cfg.joint.jointnet.pred_hidden = cfg.model_defaults.pred_hidden + + super().__init__(cfg=cfg, trainer=trainer) + + # Setup decoding object + self.decoding = RNNTBPEDecoding( + decoding_cfg=self.cfg.decoding, decoder=self.decoder, joint=self.joint, tokenizer=self.tokenizer, + ) + + # Setup wer object + self.wer = RNNTBPEWER( + decoding=self.decoding, batch_dim_index=0, use_cer=False, log_prediction=True, dist_sync_on_step=True + ) + + # Setup fused Joint step if flag is set + if self.joint.fuse_loss_wer: + self.joint.set_loss(self.loss) + self.joint.set_wer(self.wer) + + def change_vocabulary( + self, new_tokenizer_dir: str, new_tokenizer_type: str, decoding_cfg: Optional[DictConfig] = None + ): + """ + Changes vocabulary used during RNNT decoding process. Use this method when fine-tuning on from pre-trained model. + This method changes only decoder and leaves encoder and pre-processing modules unchanged. For example, you would + use it if you want to use pretrained encoder when fine-tuning on data in another language, or when you'd need + model to learn capitalization, punctuation and/or special characters. + + Args: + new_tokenizer_dir: Directory path to tokenizer. + new_tokenizer_type: Type of tokenizer. Can be either `bpe` or `wpe`. + decoding_cfg: A config for the decoder, which is optional. If the decoding type + needs to be changed (from say Greedy to Beam decoding etc), the config can be passed here. + + Returns: None + + """ + if not os.path.isdir(new_tokenizer_dir): + raise NotADirectoryError( + f'New tokenizer dir must be non-empty path to a directory. But I got: {new_tokenizer_dir}' + ) + + if new_tokenizer_type.lower() not in ('bpe', 'wpe'): + raise ValueError(f'New tokenizer type must be either `bpe` or `wpe`') + + tokenizer_cfg = OmegaConf.create({'dir': new_tokenizer_dir, 'type': new_tokenizer_type}) + + # Setup the tokenizer + self._setup_tokenizer(tokenizer_cfg) + + # Initialize a dummy vocabulary + vocabulary = self.tokenizer.tokenizer.get_vocab() + + joint_config = self.joint.to_config_dict() + new_joint_config = copy.deepcopy(joint_config) + new_joint_config['vocabulary'] = ListConfig(list(vocabulary.values())) + new_joint_config['num_classes'] = len(vocabulary) + del self.joint + self.joint = EncDecRNNTBPEModel.from_config_dict(new_joint_config) + + decoder_config = self.decoder.to_config_dict() + new_decoder_config = copy.deepcopy(decoder_config) + new_decoder_config.vocab_size = len(vocabulary) + del self.decoder + self.decoder = EncDecRNNTBPEModel.from_config_dict(new_decoder_config) + + del self.loss + self.loss = RNNTLoss(blank_idx=self.decoder.blank_idx) + + if decoding_cfg is None: + # Assume same decoding config as before + decoding_cfg = self.cfg.decoding + + self.decoding = RNNTBPEDecoding( + decoding_cfg=decoding_cfg, decoder=self.decoder, joint=self.joint, tokenizer=self.tokenizer, + ) + + self.wer = RNNTBPEWER( + decoding=self.decoding, + batch_dim_index=self.wer.batch_dim_index, + use_cer=self.wer.use_cer, + log_prediction=self.wer.log_prediction, + dist_sync_on_step=True, + ) + + # Setup fused Joint step + if self.joint.fuse_loss_wer: + self.joint.set_loss(self.loss) + self.joint.set_wer(self.wer) + + # Update config + with open_dict(self.cfg.joint): + self.cfg.joint = new_joint_config + + with open_dict(self.cfg.decoder): + self.cfg.decoder = new_decoder_config + + with open_dict(self.cfg.decoding): + self.cfg.decoding = decoding_cfg + + logging.info(f"Changed decoder to output to {self.joint.vocabulary} vocabulary.") + + def _setup_dataloader_from_config(self, config: Optional[Dict]): + if 'augmentor' in config: + augmentor = process_augmentations(config['augmentor']) + else: + augmentor = None + + shuffle = config['shuffle'] + + # Instantiate tarred dataset loader or normal dataset loader + if config.get('is_tarred', False): + if ('tarred_audio_filepaths' in config and config['tarred_audio_filepaths'] is None) or ( + 'manifest_filepath' in config and config['manifest_filepath'] is None + ): + logging.warning( + "Could not load dataset as `manifest_filepath` was None or " + f"`tarred_audio_filepaths` is None. Provided config : {config}" + ) + return None + + shuffle_n = config.get('shuffle_n', 4 * config['batch_size']) if shuffle else 0 + dataset = audio_to_text_dataset.get_tarred_bpe_dataset( + config=config, + tokenizer=self.tokenizer, + shuffle_n=shuffle_n, + global_rank=self.global_rank, + world_size=self.world_size, + augmentor=augmentor, + ) + shuffle = False + else: + if 'manifest_filepath' in config and config['manifest_filepath'] is None: + logging.warning(f"Could not load dataset as `manifest_filepath` was None. Provided config : {config}") + return None + + dataset = audio_to_text_dataset.get_bpe_dataset( + config=config, tokenizer=self.tokenizer, augmentor=augmentor + ) + + return torch.utils.data.DataLoader( + dataset=dataset, + batch_size=config['batch_size'], + collate_fn=dataset.collate_fn, + drop_last=config.get('drop_last', False), + shuffle=shuffle, + num_workers=config.get('num_workers', 0), + pin_memory=config.get('pin_memory', False), + ) + + def _setup_transcribe_dataloader(self, config: Dict) -> 'torch.utils.data.DataLoader': + """ + Setup function for a temporary data loader which wraps the provided audio file. + + Args: + config: A python dictionary which contains the following keys: + paths2audio_files: (a list) of paths to audio files. The files should be relatively short fragments. \ + Recommended length per file is between 5 and 25 seconds. + batch_size: (int) batch size to use during inference. \ + Bigger will result in better throughput performance but would use more memory. + temp_dir: (str) A temporary directory where the audio manifest is temporarily + stored. + + Returns: + A pytorch DataLoader for the given audio file(s). + """ + dl_config = { + 'manifest_filepath': os.path.join(config['temp_dir'], 'manifest.json'), + 'sample_rate': self.preprocessor._sample_rate, + 'batch_size': min(config['batch_size'], len(config['paths2audio_files'])), + 'shuffle': False, + } + + temporary_datalayer = self._setup_dataloader_from_config(config=DictConfig(dl_config)) + return temporary_datalayer diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/rnnt_models.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/rnnt_models.py new file mode 100644 index 0000000000000000000000000000000000000000..09335fff9078728540b137b6951bac935944431b --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/rnnt_models.py @@ -0,0 +1,624 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import itertools +import json +import os +import tempfile +from math import ceil +from typing import Dict, List, Optional, Union + +import torch +from omegaconf import DictConfig, open_dict +from pytorch_lightning import Trainer + +from nemo.collections.asr.data import audio_to_text_dataset +from nemo.collections.asr.data.audio_to_text_dali import DALIOutputs +from nemo.collections.asr.losses.rnnt import RNNTLoss +from nemo.collections.asr.metrics.rnnt_wer import RNNTWER, RNNTDecoding +from nemo.collections.asr.models.asr_model import ASRModel +from nemo.collections.asr.modules.rnnt_abstract import AbstractRNNTDecoder +from nemo.collections.asr.parts.perturb import process_augmentations +from nemo.core.classes.common import PretrainedModelInfo, typecheck +from nemo.core.neural_types import AcousticEncodedRepresentation, AudioSignal, LengthsType, NeuralType, SpectrogramType, \ + VoidType +from nemo.utils import logging + +try: + import warprnnt_pytorch as warprnnt + + WARP_RNNT_AVAILABLE = True +except (ImportError, ModuleNotFoundError): + WARP_RNNT_AVAILABLE = False + + +class EncDecRNNTModel(ASRModel): + """Base class for encoder decoder RNNT-based models.""" + + @classmethod + def list_available_models(cls) -> Optional[PretrainedModelInfo]: + """ + This method returns a list of pre-trained model which can be instantiated directly from NVIDIA's NGC cloud. + + Returns: + List of available pre-trained models. + """ + result = [] + return result + + def __init__(self, cfg: DictConfig, trainer: Trainer = None): + # Required loss function + if not WARP_RNNT_AVAILABLE: + raise ImportError( + "Could not import `warprnnt_pytorch`.\n" + "Please visit https://github.com/HawkAaron/warp-transducer " + "and follow the steps in the readme to build and install the " + "pytorch bindings for RNNT Loss, or use the provided docker " + "container that supports RNN-T loss." + ) + + # Get global rank and total number of GPU workers for IterableDataset partitioning, if applicable + self.global_rank = 0 + self.world_size = 1 + self.local_rank = 0 + if trainer is not None: + self.global_rank = (trainer.node_rank * trainer.num_gpus) + trainer.local_rank + self.world_size = trainer.num_nodes * trainer.num_gpus + self.local_rank = trainer.local_rank + + # Update config values required by components dynamically + with open_dict(cfg.decoder): + cfg.decoder.vocab_size = len(cfg.labels) + + with open_dict(cfg.joint): + cfg.joint.num_classes = len(cfg.labels) + cfg.joint.vocabulary = cfg.labels + cfg.joint.jointnet.encoder_hidden = cfg.model_defaults.enc_hidden + cfg.joint.jointnet.pred_hidden = cfg.model_defaults.pred_hidden + cfg.joint.blank_pos = cfg.decoder.blank_pos + + super().__init__(cfg=cfg, trainer=trainer) + + # Initialize components + self.preprocessor = EncDecRNNTModel.from_config_dict(self.cfg.preprocessor) + self.encoder = EncDecRNNTModel.from_config_dict(self.cfg.encoder) + + + self.decoder: AbstractRNNTDecoder = EncDecRNNTModel.from_config_dict(self.cfg.decoder) + self.joint = EncDecRNNTModel.from_config_dict(self.cfg.joint) + self.loss = RNNTLoss(blank_idx=self.decoder.blank_idx) + + if hasattr(self.cfg, 'spec_augment') and self._cfg.spec_augment is not None: + self.spec_augmentation = EncDecRNNTModel.from_config_dict(self.cfg.spec_augment) + else: + self.spec_augmentation = None + + # Setup decoding objects + self.decoding = RNNTDecoding( + decoding_cfg=self.cfg.decoding, decoder=self.decoder, joint=self.joint, vocabulary=self.joint.vocabulary, + ) + # Setup WER calculation + self.wer = RNNTWER( + decoding=self.decoding, batch_dim_index=0, use_cer=False, log_prediction=True, dist_sync_on_step=True + ) + + # Whether to compute loss during evaluation + if 'compute_eval_loss' in self.cfg: + self.compute_eval_loss = self.cfg.compute_eval_loss + else: + self.compute_eval_loss = True + + # Setup fused Joint step if flag is set + if self.joint.fuse_loss_wer: + self.joint.set_loss(self.loss) + self.joint.set_wer(self.wer) + + # setting up the variational noise for the decoder + if hasattr(self.cfg, 'variational_noise'): + self._optim_variational_noise_std = self.cfg['variational_noise'].get('std', None) + self._optim_variational_noise_start = self.cfg['variational_noise'].get('start_step', 0) + else: + self._optim_variational_noise_std = 0 + self._optim_variational_noise_start = 0 + + @torch.no_grad() + def transcribe(self, paths2audio_files: List[str], batch_size: int = 4) -> List[str]: + """ + Uses greedy decoding to transcribe audio files. Use this method for debugging and prototyping. + + Args: + + paths2audio_files: (a list) of paths to audio files. \ + Recommended length per file is between 5 and 25 seconds. \ + But it is possible to pass a few hours long file if enough GPU memory is available. + batch_size: (int) batch size to use during inference. \ + Bigger will result in better throughput performance but would use more memory. + + Returns: + + A list of transcriptions in the same order as paths2audio_files + """ + if paths2audio_files is None or len(paths2audio_files) == 0: + return {} + # We will store transcriptions here + hypotheses = [] + # Model's mode and device + mode = self.training + device = next(self.parameters()).device + try: + # Switch model to evaluation mode + self.eval() + logging_level = logging.get_verbosity() + logging.set_verbosity(logging.WARNING) + # Work in tmp directory - will store manifest file there + with tempfile.TemporaryDirectory() as tmpdir: + with open(os.path.join(tmpdir, 'manifest.json'), 'w') as fp: + for audio_file in paths2audio_files: + entry = {'audio_filepath': audio_file, 'duration': 100000, 'text': 'nothing'} + fp.write(json.dumps(entry) + '\n') + + config = {'paths2audio_files': paths2audio_files, 'batch_size': batch_size, 'temp_dir': tmpdir} + + temporary_datalayer = self._setup_transcribe_dataloader(config) + for test_batch in temporary_datalayer: + encoded, encoded_len = self.forward( + input_signal=test_batch[0].to(device), input_signal_length=test_batch[1].to(device) + ) + hypotheses += self.decoding.rnnt_decoder_predictions_tensor(encoded, encoded_len) + del test_batch + finally: + # set mode back to its original value + self.train(mode=mode) + logging.set_verbosity(logging_level) + return hypotheses + + def change_vocabulary(self, new_vocabulary: List[str], decoding_cfg: Optional[DictConfig] = None): + """ + Changes vocabulary used during RNNT decoding process. Use this method when fine-tuning a pre-trained model. + This method changes only decoder and leaves encoder and pre-processing modules unchanged. For example, you would + use it if you want to use pretrained encoder when fine-tuning on data in another language, or when you'd need + model to learn capitalization, punctuation and/or special characters. + + Args: + new_vocabulary: list with new vocabulary. Must contain at least 2 elements. Typically, \ + this is target alphabet. + decoding_cfg: A config for the decoder, which is optional. If the decoding type + needs to be changed (from say Greedy to Beam decoding etc), the config can be passed here. + + Returns: None + + """ + if self.joint.vocabulary == new_vocabulary: + logging.warning(f"Old {self.joint.vocabulary} and new {new_vocabulary} match. Not changing anything.") + else: + if new_vocabulary is None or len(new_vocabulary) == 0: + raise ValueError(f'New vocabulary must be non-empty list of chars. But I got: {new_vocabulary}') + + joint_config = self.joint.to_config_dict() + new_joint_config = copy.deepcopy(joint_config) + new_joint_config['vocabulary'] = new_vocabulary + new_joint_config['num_classes'] = len(new_vocabulary) + del self.joint + self.joint = EncDecRNNTModel.from_config_dict(new_joint_config) + + decoder_config = self.decoder.to_config_dict() + new_decoder_config = copy.deepcopy(decoder_config) + new_decoder_config.vocab_size = len(new_vocabulary) + del self.decoder + self.decoder = EncDecRNNTModel.from_config_dict(new_decoder_config) + + del self.loss + self.loss = RNNTLoss(blank_idx=self.decoder.blank_idx) + + if decoding_cfg is None: + # Assume same decoding config as before + decoding_cfg = self.cfg.decoding + + self.decoding = RNNTDecoding( + decoding_cfg=decoding_cfg, decoder=self.decoder, joint=self.joint, vocabulary=self.joint.vocabulary, + ) + + self.wer = RNNTWER( + decoding=self.decoding, + batch_dim_index=self.wer.batch_dim_index, + use_cer=self.wer.use_cer, + log_prediction=self.wer.log_prediction, + dist_sync_on_step=True, + ) + + # Setup fused Joint step + if self.joint.fuse_loss_wer: + self.joint.set_loss(self.loss) + self.joint.set_wer(self.wer) + + # Update config + with open_dict(self.cfg.joint): + self.cfg.joint = new_joint_config + + with open_dict(self.cfg.decoder): + self.cfg.decoder = new_decoder_config + + with open_dict(self.cfg.decoding): + self.cfg.decoding = decoding_cfg + + logging.info(f"Changed decoder to output to {self.joint.vocabulary} vocabulary.") + + def _setup_dataloader_from_config(self, config: Optional[Dict]): + if 'augmentor' in config: + augmentor = process_augmentations(config['augmentor']) + else: + augmentor = None + + shuffle = config['shuffle'] + device = 'gpu' if torch.cuda.is_available() else 'cpu' + if config.get('use_dali', False): + device_id = self.local_rank if device == 'gpu' else None + dataset = audio_to_text_dataset.get_dali_char_dataset( + config=config, + shuffle=shuffle, + device_id=device_id, + global_rank=self.global_rank, + world_size=self.world_size, + preprocessor_cfg=self._cfg.preprocessor, + ) + return dataset + + # Instantiate tarred dataset loader or normal dataset loader + if config.get('is_tarred', False): + if ('tarred_audio_filepaths' in config and config['tarred_audio_filepaths'] is None) or ( + 'manifest_filepath' in config and config['manifest_filepath'] is None + ): + logging.warning( + "Could not load dataset as `manifest_filepath` was None or " + f"`tarred_audio_filepaths` is None. Provided config : {config}" + ) + return None + + shuffle_n = config.get('shuffle_n', 4 * config['batch_size']) if shuffle else 0 + dataset = audio_to_text_dataset.get_tarred_char_dataset( + config=config, + shuffle_n=shuffle_n, + global_rank=self.global_rank, + world_size=self.world_size, + augmentor=augmentor, + ) + shuffle = False + else: + if 'manifest_filepath' in config and config['manifest_filepath'] is None: + logging.warning(f"Could not load dataset as `manifest_filepath` was None. Provided config : {config}") + return None + + dataset = audio_to_text_dataset.get_char_dataset(config=config, augmentor=augmentor) + + return torch.utils.data.DataLoader( + dataset=dataset, + batch_size=config['batch_size'], + collate_fn=dataset.collate_fn, + drop_last=config.get('drop_last', False), + shuffle=shuffle, + num_workers=config.get('num_workers', 0), + pin_memory=config.get('pin_memory', False), + ) + + def setup_training_data(self, train_data_config: Optional[Union[DictConfig, Dict]]): + if 'shuffle' not in train_data_config: + train_data_config['shuffle'] = True + + # preserve config + self._update_dataset_config(dataset_name='train', config=train_data_config) + + self._train_dl = self._setup_dataloader_from_config(config=train_data_config) + + # Need to set this because if using an IterableDataset, the length of the dataloader is the total number + # of samples rather than the number of batches, and this messes up the tqdm progress bar. + # So we set the number of steps manually (to the correct number) to fix this. + if 'is_tarred' in train_data_config and train_data_config['is_tarred']: + # We also need to check if limit_train_batches is already set. + # If it's an int, we assume that the user has set it to something sane, i.e. <= # training batches, + # and don't change it. Otherwise, adjust batches accordingly if it's a float (including 1.0). + if isinstance(self._trainer.limit_train_batches, float): + self._trainer.limit_train_batches = int( + self._trainer.limit_train_batches + * ceil((len(self._train_dl.dataset) / self.world_size) / train_data_config['batch_size']) + ) + + def setup_validation_data(self, val_data_config: Optional[Union[DictConfig, Dict]]): + if 'shuffle' not in val_data_config: + val_data_config['shuffle'] = False + + # preserve config + self._update_dataset_config(dataset_name='validation', config=val_data_config) + + self._validation_dl = self._setup_dataloader_from_config(config=val_data_config) + + def setup_test_data(self, test_data_config: Optional[Union[DictConfig, Dict]]): + if 'shuffle' not in test_data_config: + test_data_config['shuffle'] = False + + # preserve config + self._update_dataset_config(dataset_name='test', config=test_data_config) + + self._test_dl = self._setup_dataloader_from_config(config=test_data_config) + + @property + def input_types(self) -> Optional[Dict[str, NeuralType]]: + if hasattr(self.preprocessor, '_sample_rate'): + input_signal_eltype = AudioSignal(freq=self.preprocessor._sample_rate) + else: + input_signal_eltype = AudioSignal() + + return { + "input_signal": NeuralType(('B', 'T'), input_signal_eltype, optional=True), + "input_signal_length": NeuralType(tuple('B'), LengthsType(), optional=True), + "processed_signal": NeuralType(('B', 'D', 'T'), SpectrogramType(), optional=True), + "processed_signal_length": NeuralType(tuple('B'), LengthsType(), optional=True), + } + + @property + def output_types(self) -> Optional[Dict[str, NeuralType]]: + return { + "outputs": NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation()), + "encoded_lengths": NeuralType(tuple('B'), LengthsType()), + "extra": NeuralType(elements_type=VoidType()), + } + + @typecheck() + def forward( + self, input_signal=None, input_signal_length=None, processed_signal=None, processed_signal_length=None + ): + has_input_signal = input_signal is not None and input_signal_length is not None + has_processed_signal = processed_signal is not None and processed_signal_length is not None + if (has_input_signal ^ has_processed_signal) is False: + raise ValueError( + f"{self} Arguments ``input_signal`` and ``input_signal_length`` are mutually exclusive " + " with ``processed_signal`` and ``processed_signal_len`` arguments." + ) + + if not has_processed_signal: + processed_signal, processed_signal_length = self.preprocessor( + input_signal=input_signal, length=input_signal_length, + ) + + # Spec augment is not applied during evaluation/testing + if self.spec_augmentation is not None and self.training: + processed_signal = self.spec_augmentation(input_spec=processed_signal, length=processed_signal_length) + + encoded, encoded_len, extra = self.encoder(audio_signal=processed_signal, length=processed_signal_length) + return encoded, encoded_len, extra + + # PTL-specific methods + def training_step(self, batch, batch_nb): + signal, signal_len, transcript, transcript_len = batch + + # forward() only performs encoder forward + if isinstance(batch, DALIOutputs) and batch.has_processed_signal: + encoded, encoded_len, extra = self.forward(processed_signal=signal, processed_signal_length=signal_len) + else: + encoded, encoded_len, extra = self.forward(input_signal=signal, input_signal_length=signal_len) + del signal + + # During training, loss must be computed, so decoder forward is necessary + decoder, target_length = self.decoder(targets=transcript, target_length=transcript_len) + + if hasattr(self, '_trainer') and self._trainer is not None: + log_every_n_steps = self._trainer.log_every_n_steps + sample_id = self._trainer.global_step + + else: + log_every_n_steps = 1 + sample_id = batch_nb + + tensorboard = self.logger.experiment + + # If experimental fused Joint-Loss-WER is not used + if not self.joint.fuse_loss_wer: + # Compute full joint and loss + joint = self.joint(encoder_outputs=encoded, decoder_outputs=decoder) + loss_value = self.loss( + log_probs=joint, targets=transcript, input_lengths=encoded_len, target_lengths=target_length + ) + + tensorboard.add_scalar('train_loss', loss_value, self.global_step) + tensorboard.add_scalar('learning_rate', self._optimizer.param_groups[0]['lr'], self.global_step) + + # if (sample_id + 1) % log_every_n_steps == 0: + # self.wer.update(encoded, encoded_len, transcript, transcript_len) + # _, scores, words = self.wer.compute() + # tensorboard_logs.update({'training_batch_wer': scores.float() / words}) + + else: + # If experimental fused Joint-Loss-WER is used + if (sample_id + 1) % log_every_n_steps == 0: + compute_wer = True + else: + compute_wer = False + + # Fused joint step + loss_value, wer, _, _ = self.joint( + encoder_outputs=encoded, + decoder_outputs=decoder, + encoder_lengths=encoded_len, + transcripts=transcript, + transcript_lengths=transcript_len, + compute_wer=compute_wer, + ) + + tensorboard.add_scalar('train_loss', loss_value, self.global_step) + tensorboard.add_scalar('learning_rate', self._optimizer.param_groups[0]['lr'], self.global_step) + + if compute_wer: + tensorboard.add_scalar('training_batch_wer', wer, self.global_step) + + if extra[0] is not None: + adaptive_ffn_loss = extra[0] + tensorboard.add_scalar('train_adaffn_loss', adaptive_ffn_loss, self.global_step) + loss_value = loss_value + adaptive_ffn_loss + + return {'loss': loss_value} + + def validation_step(self, batch, batch_idx, dataloader_idx=0, decode_results=None): + signal, signal_len, transcript, transcript_len = batch + + # forward() only performs encoder forward + if isinstance(batch, DALIOutputs) and batch.has_processed_signal: + encoded, encoded_len, extra = self.forward(processed_signal=signal, processed_signal_length=signal_len) + else: + encoded, encoded_len, extra = self.forward(input_signal=signal, input_signal_length=signal_len) + del signal + + tensorboard_logs = {} + + log_prediction = batch_idx < 3 + + # If experimental fused Joint-Loss-WER is not used + if not self.joint.fuse_loss_wer: + if self.compute_eval_loss: + decoder, target_length = self.decoder(targets=transcript, target_length=transcript_len) + joint = self.joint(encoder_outputs=encoded, decoder_outputs=decoder) + + loss_value = self.loss( + log_probs=joint, targets=transcript, input_lengths=encoded_len, target_lengths=target_length + ) + + tensorboard_logs['val_loss'] = loss_value + + self.wer.update(encoded, encoded_len, transcript, transcript_len, decode_results=decode_results, log_prediction=log_prediction) + wer, wer_num, wer_denom = self.wer.compute() + + tensorboard_logs['val_wer_num'] = wer_num + tensorboard_logs['val_wer_denom'] = wer_denom + tensorboard_logs['val_wer'] = wer + + else: + # If experimental fused Joint-Loss-WER is used + compute_wer = True + + if self.compute_eval_loss: + decoded, target_len = self.decoder(targets=transcript, target_length=transcript_len) + else: + decoded = None + target_len = transcript_len + + # Fused joint step + loss_value, wer, wer_num, wer_denom = self.joint( + encoder_outputs=encoded, + decoder_outputs=decoded, + encoder_lengths=encoded_len, + transcripts=transcript, + transcript_lengths=target_len, + compute_wer=compute_wer, + decode_results=decode_results, + log_prediction=log_prediction + ) + + if loss_value is not None: + tensorboard_logs['val_loss'] = loss_value + + tensorboard_logs['val_wer_num'] = wer_num + tensorboard_logs['val_wer_denom'] = wer_denom + tensorboard_logs['val_wer'] = wer + + if extra[0] is not None: + tensorboard_logs['val_adaffn_loss'] = extra[0] + + return tensorboard_logs + + def test_step(self, batch, batch_idx, dataloader_idx=0): + decode_results = {} + logs = self.validation_step(batch, batch_idx, dataloader_idx=dataloader_idx, decode_results=decode_results) + test_logs = { + 'test_wer_num': logs['val_wer_num'], + 'test_wer_denom': logs['val_wer_denom'], + # 'test_wer': logs['val_wer'], + 'test_references': decode_results['references'], + 'test_hypotheses': decode_results['hypotheses'], + } + if 'val_loss' in logs: + test_logs['test_loss'] = logs['val_loss'] + if 'val_adaffn_loss' in logs: + test_logs['test_adaffn_loss'] = logs['val_adaffn_loss'] + return test_logs + + def multi_validation_epoch_end(self, outputs, dataloader_idx: int = 0): + if self.compute_eval_loss: + val_loss_mean = torch.stack([x['val_loss'] for x in outputs]).mean() + val_loss_log = {'val_loss': val_loss_mean} + else: + val_loss_log = {} + wer_num = torch.stack([x['val_wer_num'] for x in outputs]).sum() + wer_denom = torch.stack([x['val_wer_denom'] for x in outputs]).sum() + tensorboard_logs = {**val_loss_log, 'val_wer': wer_num.float() / wer_denom} + if 'val_adaffn_loss' in outputs[0]: + adaffn_loss = torch.stack([x['val_adaffn_loss'] for x in outputs]).mean() + tensorboard_logs['val_adaffn_loss'] = adaffn_loss + return {**val_loss_log, 'log': tensorboard_logs} + + def multi_test_epoch_end(self, outputs, dataloader_idx: int = 0): + if self.compute_eval_loss: + test_loss_mean = torch.stack([x['test_loss'] for x in outputs]).mean() + test_loss_log = {'test_loss': test_loss_mean} + else: + test_loss_log = {} + wer_num = torch.stack([x['test_wer_num'] for x in outputs]).sum() + wer_denom = torch.stack([x['test_wer_denom'] for x in outputs]).sum() + tensorboard_logs = {**test_loss_log, 'test_wer': wer_num.float() / wer_denom} + references = itertools.chain.from_iterable([x['test_references'] for x in outputs]) + hypotheses = itertools.chain.from_iterable([x['test_hypotheses'] for x in outputs]) + if 'test_adaffn_loss' in outputs[0]: + adaffn_loss = torch.stack([x['test_adaffn_loss'] for x in outputs]).mean() + tensorboard_logs['test_adaffn_loss'] = adaffn_loss + return {**test_loss_log, 'log': tensorboard_logs, 'decode_results': (references, hypotheses)} + + def _setup_transcribe_dataloader(self, config: Dict) -> 'torch.utils.data.DataLoader': + """ + Setup function for a temporary data loader which wraps the provided audio file. + + Args: + config: A python dictionary which contains the following keys: + paths2audio_files: (a list) of paths to audio files. The files should be relatively short fragments. \ + Recommended length per file is between 5 and 25 seconds. + batch_size: (int) batch size to use during inference. \ + Bigger will result in better throughput performance but would use more memory. + temp_dir: (str) A temporary directory where the audio manifest is temporarily + stored. + + Returns: + A pytorch DataLoader for the given audio file(s). + """ + dl_config = { + 'manifest_filepath': os.path.join(config['temp_dir'], 'manifest.json'), + 'sample_rate': self.preprocessor._sample_rate, + 'labels': self.joint.vocabulary, + 'batch_size': min(config['batch_size'], len(config['paths2audio_files'])), + 'trim_silence': True, + 'shuffle': False, + } + + temporary_datalayer = self._setup_dataloader_from_config(config=DictConfig(dl_config)) + return temporary_datalayer + + def on_after_backward(self): + super().on_after_backward() + if self._optim_variational_noise_std > 0 and self.global_step >= self._optim_variational_noise_start: + for param_name, param in self.decoder.named_parameters(): + if param.grad is not None: + noise = torch.normal( + mean=0.0, + std=self._optim_variational_noise_std, + size=param.size(), + device=param.device, + dtype=param.dtype, + ) + param.grad.data.add_(noise) diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/spec2vec/__init__.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/spec2vec/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/spec2vec/ctc_finetune.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/spec2vec/ctc_finetune.py new file mode 100644 index 0000000000000000000000000000000000000000..64bc3db7c994d633b903015fc93b8f55cfe96197 --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/spec2vec/ctc_finetune.py @@ -0,0 +1,547 @@ +import contextlib +import copy +import itertools +import json +import os +import tempfile +from math import ceil +from typing import Dict, List, Optional, Union + +import torch +from omegaconf import DictConfig, OmegaConf, open_dict, ListConfig +from pytorch_lightning import Trainer +from tqdm.auto import tqdm + +from nemo.collections.asr.data import audio_to_text_dataset +from nemo.collections.asr.losses.ctc import CTCLoss +from nemo.collections.asr.metrics.wer import WER, WER_phone +from nemo.collections.asr.metrics.wer_bpe import WERBPE +from nemo.collections.asr.models.asr_model import ASRModel +from nemo.collections.asr.parts.perturb import process_augmentations, RandomNoisePerturbation, AudioAugmentor +from nemo.utils import logging + + +class CTCFinetuneModel(ASRModel): + def __init__(self, cfg: DictConfig, trainer: Trainer = None): + # Get global rank and total number of GPU workers for IterableDataset partitioning, if applicable + # Global_rank and local_rank is set by LightningModule in Lightning 1.2.0 + self.global_rank = 0 + self.world_size = 1 + self.local_rank = 0 + if trainer is not None: + self.global_rank = (trainer.node_rank * trainer.num_gpus) + trainer.local_rank + self.world_size = trainer.num_nodes * trainer.num_gpus + self.local_rank = trainer.local_rank + + self.use_bpe = cfg.tokenizer is not None + if self.use_bpe: + from nemo.collections.asr.parts.mixins import ASRBPEMixin + self.bpe = ASRBPEMixin() + self.bpe._setup_tokenizer(cfg.tokenizer, register_artifact=False) + self.tokenizer = self.bpe.tokenizer + + # Initialize a dummy vocabulary + vocabulary = self.tokenizer.tokenizer.get_vocab() + + # Set the new vocabulary + assert len(cfg.decoder.vocabulary) == 0 + with open_dict(cfg): + cfg.decoder.vocabulary = ListConfig(list(vocabulary.values())) + else: + self.label_type = cfg.get('label_type', 'char') + self.add_end_space = cfg.add_end_space + self.lang = cfg.lang + + super().__init__(cfg=cfg, trainer=trainer) + + if self._cfg.encoder_type == 'spec2vec': + from nemo.collections.asr.models.spec2vec.spec2vec_model import Spec2VecEncoder + self.encoder = Spec2VecEncoder(self._cfg.encoder) + encoder_param_prefix = 'spec2vec_encoder.' + elif self._cfg.encoder_type == 'st': + from nemo.collections.asr.models.st2vec.st2vec_model import ST2VecEncoder + self.encoder = ST2VecEncoder(self._cfg.encoder) + encoder_param_prefix = 'st2vec_encoder.' + elif self._cfg.encoder_type == 'feat_st': + from nemo.collections.asr.models.st2vec.st2vec_model import FeatST2VecEncoder + self.encoder = FeatST2VecEncoder(self._cfg.encoder) + encoder_param_prefix = 'st2vec_encoder.' + else: + assert self._cfg.encoder_type == 'wav2vec' + from nemo.collections.asr.modules.wav2vec_encoder import Wav2VecEncoderModel + self.encoder = Wav2VecEncoderModel(self._cfg.encoder) + encoder_param_prefix = None + if cfg.pretrain_chkpt_path is not None: + self.init_encoder_from_pretrain_model(self.encoder, encoder_param_prefix, cfg.pretrain_chkpt_path) + if self._cfg.encoder_type == 'st': + self.encoder.remove_pretraining_modules(use_teacher_encoder=self._cfg.use_teacher_encoder) + else: + self.encoder.remove_pretraining_modules() + + self.decoder = self.from_config_dict(self._cfg.decoder) + + self.freeze_finetune_updates = self._cfg.freeze_finetune_updates + + self.loss = CTCLoss( + blank_id=self.decoder.blank_idx, + zero_infinity=True, + reduction=self._cfg.get("ctc_reduction", "mean_batch"), + ) + + # Setup metric objects + if self.use_bpe: + self._wer = WERBPE( + tokenizer=self.tokenizer, + blank_id=self.decoder.blank_idx, + batch_dim_index=0, + use_cer=self._cfg.get('use_cer', False), + ctc_decode=True, + dist_sync_on_step=True, + log_prediction=self._cfg.get("log_prediction", False), + lang=self.lang, + ) + else: + if self.label_type == 'phone': + WER_class = WER_phone + else: + WER_class = WER + self._wer = WER_class( + vocabulary=self.decoder.vocabulary, + blank_id=self.decoder.blank_idx, + batch_dim_index=0, + use_cer=self._cfg.get('use_cer', False), + ctc_decode=True, + dist_sync_on_step=True, + log_prediction=self._cfg.get("log_prediction", False), + strip_end_space=self.add_end_space, + ) + + @torch.no_grad() + def transcribe( + self, paths2audio_files: List[str], batch_size: int = 4, logprobs=False, return_hypotheses: bool = False + ) -> List[str]: + """ + Uses greedy decoding to transcribe audio files. Use this method for debugging and prototyping. + + Args: + paths2audio_files: (a list) of paths to audio files. \ + Recommended length per file is between 5 and 25 seconds. \ + But it is possible to pass a few hours long file if enough GPU memory is available. + batch_size: (int) batch size to use during inference. + Bigger will result in better throughput performance but would use more memory. + logprobs: (bool) pass True to get log probabilities instead of transcripts. + return_hypotheses: (bool) Either return hypotheses or text + With hypotheses can do some postprocessing like getting timestamp or rescoring + + Returns: + A list of transcriptions (or raw log probabilities if logprobs is True) in the same order as paths2audio_files + """ + if paths2audio_files is None or len(paths2audio_files) == 0: + return {} + + if return_hypotheses and logprobs: + raise ValueError( + "Either `return_hypotheses` or `logprobs` can be True at any given time." + "Returned hypotheses will contain the logprobs." + ) + + # We will store transcriptions here + hypotheses = [] + # Model's mode and device + mode = self.training + device = next(self.parameters()).device + dither_value = self.preprocessor.featurizer.dither + pad_to_value = self.preprocessor.featurizer.pad_to + + try: + self.preprocessor.featurizer.dither = 0.0 + self.preprocessor.featurizer.pad_to = 0 + # Switch model to evaluation mode + self.eval() + # Freeze the encoder and decoder modules + self.encoder.freeze() + self.decoder.freeze() + logging_level = logging.get_verbosity() + logging.set_verbosity(logging.WARNING) + # Work in tmp directory - will store manifest file there + with tempfile.TemporaryDirectory() as tmpdir: + with open(os.path.join(tmpdir, 'manifest.json'), 'w') as fp: + for audio_file in paths2audio_files: + entry = {'audio_filepath': audio_file, 'duration': 100000, 'text': 'nothing'} + fp.write(json.dumps(entry) + '\n') + + config = {'paths2audio_files': paths2audio_files, 'batch_size': batch_size, 'temp_dir': tmpdir} + + temporary_datalayer = self._setup_transcribe_dataloader(config) + for test_batch in tqdm(temporary_datalayer, desc="Transcribing"): + logits, logits_len, greedy_predictions = self.forward( + input_signal=test_batch[0].to(device), input_signal_length=test_batch[1].to(device) + ) + if logprobs: + # dump log probs per file + for idx in range(logits.shape[0]): + hypotheses.append(logits[idx][: logits_len[idx]]) + else: + current_hypotheses = self._wer.ctc_decoder_predictions_tensor( + greedy_predictions, predictions_len=logits_len, return_hypotheses=return_hypotheses, + ) + + if return_hypotheses: + # dump log probs per file + for idx in range(logits.shape[0]): + current_hypotheses[idx].y_sequence = logits[idx][: logits_len[idx]] + + hypotheses += current_hypotheses + + del greedy_predictions + del logits + del test_batch + finally: + # set mode back to its original value + self.train(mode=mode) + self.preprocessor.featurizer.dither = dither_value + self.preprocessor.featurizer.pad_to = pad_to_value + if mode is True: + self.encoder.unfreeze() + self.decoder.unfreeze() + logging.set_verbosity(logging_level) + return hypotheses + + def change_vocabulary(self, new_vocabulary: List[str]): + """ + Changes vocabulary used during CTC decoding process. Use this method when fine-tuning on from pre-trained model. + This method changes only decoder and leaves encoder and pre-processing modules unchanged. For example, you would + use it if you want to use pretrained encoder when fine-tuning on a data in another language, or when you'd need + model to learn capitalization, punctuation and/or special characters. + + If new_vocabulary == self.decoder.vocabulary then nothing will be changed. + + Args: + + new_vocabulary: list with new vocabulary. Must contain at least 2 elements. Typically, \ + this is target alphabet. + + Returns: None + + """ + assert not self.use_bpe + if self.decoder.vocabulary == new_vocabulary: + logging.warning(f"Old {self.decoder.vocabulary} and new {new_vocabulary} match. Not changing anything.") + else: + if new_vocabulary is None or len(new_vocabulary) == 0: + raise ValueError(f'New vocabulary must be non-empty list of chars. But I got: {new_vocabulary}') + decoder_config = self.decoder.to_config_dict() + new_decoder_config = copy.deepcopy(decoder_config) + new_decoder_config['vocabulary'] = new_vocabulary + new_decoder_config['num_classes'] = len(new_vocabulary) + + del self.decoder + self.decoder = self.from_config_dict(new_decoder_config) + del self.loss + self.loss = CTCLoss( + num_classes=self.decoder.num_classes_with_blank - 1, + zero_infinity=True, + reduction=self._cfg.get("ctc_reduction", "mean_batch"), + ) + if self.label_type == 'phone': + WER_class = WER_phone + else: + WER_class = WER + self._wer = WER_class( + vocabulary=self.decoder.vocabulary, + batch_dim_index=0, + use_cer=self._cfg.get('use_cer', False), + ctc_decode=True, + dist_sync_on_step=True, + log_prediction=self._cfg.get("log_prediction", False), + ) + + # Update config + OmegaConf.set_struct(self._cfg.decoder, False) + self._cfg.decoder = new_decoder_config + OmegaConf.set_struct(self._cfg.decoder, True) + + logging.info(f"Changed decoder to output to {self.decoder.vocabulary} vocabulary.") + + def _setup_dataloader_from_config(self, config: Optional[Dict], noise_perturb_config): + if noise_perturb_config is not None: + noise_perturb = RandomNoisePerturbation(**noise_perturb_config) + augmentor = AudioAugmentor(perturbations=[(1.0, noise_perturb)]) + else: + augmentor = None + + shuffle = config['shuffle'] + + if 'manifest_filepath' in config and config['manifest_filepath'] is None: + logging.warning(f"Could not load dataset as `manifest_filepath` was None. Provided config : {config}") + return None + + if self.add_end_space: + config['parser_add_end_space'] = self.add_end_space + + if self.use_bpe: + dataset = audio_to_text_dataset.get_bpe_dataset(config=config, tokenizer=self.tokenizer, + augmentor=augmentor) + else: + if self.label_type == 'char': + dataset = audio_to_text_dataset.get_char_dataset(config=config, augmentor=augmentor) + elif self.label_type == 'phone': + dataset = audio_to_text_dataset.get_phone_dataset(config=config, augmentor=augmentor) + + return torch.utils.data.DataLoader( + dataset=dataset, + batch_size=config['batch_size'], + collate_fn=dataset.collate_fn, + drop_last=config.get('drop_last', False), + shuffle=shuffle, + num_workers=config.get('num_workers', 0), + pin_memory=config.get('pin_memory', False), + ) + + def setup_training_data(self, train_data_config: Optional[Union[DictConfig, Dict]]): + """ + Sets up the training data loader via a Dict-like object. + + Args: + train_data_config: A config that contains the information regarding construction + of an ASR Training dataset. + + Supported Datasets: + - :class:`~nemo.collections.asr.data.audio_to_text.AudioToCharDataset` + - :class:`~nemo.collections.asr.data.audio_to_text.AudioToBPEDataset` + - :class:`~nemo.collections.asr.data.audio_to_text.TarredAudioToCharDataset` + - :class:`~nemo.collections.asr.data.audio_to_text.TarredAudioToBPEDataset` + - :class:`~nemo.collections.asr.data.audio_to_text_dali.AudioToCharDALIDataset` + """ + if 'shuffle' not in train_data_config: + train_data_config['shuffle'] = True + + # preserve config + self._update_dataset_config(dataset_name='train', config=train_data_config) + + self._train_dl = self._setup_dataloader_from_config(config=train_data_config, + noise_perturb_config=self._cfg['noise_perturb']) + + # Need to set this because if using an IterableDataset, the length of the dataloader is the total number + # of samples rather than the number of batches, and this messes up the tqdm progress bar. + # So we set the number of steps manually (to the correct number) to fix this. + if 'is_tarred' in train_data_config and train_data_config['is_tarred']: + # We also need to check if limit_train_batches is already set. + # If it's an int, we assume that the user has set it to something sane, i.e. <= # training batches, + # and don't change it. Otherwise, adjust batches accordingly if it's a float (including 1.0). + if isinstance(self._trainer.limit_train_batches, float): + self._trainer.limit_train_batches = int( + self._trainer.limit_train_batches + * ceil((len(self._train_dl.dataset) / self.world_size) / train_data_config['batch_size']) + ) + + def setup_validation_data(self, val_data_config: Optional[Union[DictConfig, Dict]]): + """ + Sets up the validation data loader via a Dict-like object. + + Args: + val_data_config: A config that contains the information regarding construction + of an ASR Training dataset. + + Supported Datasets: + - :class:`~nemo.collections.asr.data.audio_to_text.AudioToCharDataset` + - :class:`~nemo.collections.asr.data.audio_to_text.AudioToBPEDataset` + - :class:`~nemo.collections.asr.data.audio_to_text.TarredAudioToCharDataset` + - :class:`~nemo.collections.asr.data.audio_to_text.TarredAudioToBPEDataset` + - :class:`~nemo.collections.asr.data.audio_to_text_dali.AudioToCharDALIDataset` + """ + if 'shuffle' not in val_data_config: + val_data_config['shuffle'] = False + + # preserve config + self._update_dataset_config(dataset_name='validation', config=val_data_config) + + self._validation_dl = self._setup_dataloader_from_config(config=val_data_config, noise_perturb_config=None) + + def setup_test_data(self, test_data_config: Optional[Union[DictConfig, Dict]]): + """ + Sets up the test data loader via a Dict-like object. + + Args: + test_data_config: A config that contains the information regarding construction + of an ASR Training dataset. + + Supported Datasets: + - :class:`~nemo.collections.asr.data.audio_to_text.AudioToCharDataset` + - :class:`~nemo.collections.asr.data.audio_to_text.AudioToBPEDataset` + - :class:`~nemo.collections.asr.data.audio_to_text.TarredAudioToCharDataset` + - :class:`~nemo.collections.asr.data.audio_to_text.TarredAudioToBPEDataset` + - :class:`~nemo.collections.asr.data.audio_to_text_dali.AudioToCharDALIDataset` + """ + if 'shuffle' not in test_data_config: + test_data_config['shuffle'] = False + + # preserve config + self._update_dataset_config(dataset_name='test', config=test_data_config) + + self._test_dl = self._setup_dataloader_from_config(config=test_data_config, noise_perturb_config=None) + + def optim_param_groups(self): + return [{'params': self.encoder.parameters(), 'weight_decay': 0.0}, + {'params': self.decoder.parameters()}] + + def forward(self, input_signal, input_signal_length, global_step): + """ + Args: + input_signal: Tensor that represents a batch of raw audio signals, + of shape [B, T]. T here represents timesteps, with 1 second of audio represented as + `self.sample_rate` number of floating point values. + input_signal_length: Vector of length B, that contains the individual lengths of the audio + sequences. + """ + ft = False if global_step is None else self.freeze_finetune_updates <= global_step + with torch.no_grad() if not ft else contextlib.suppress(): + encoded, encoded_len = self.encoder(input_signal, input_signal_length, None, None, + mask=self.training, features_only=True) + + # [B, T, D] => [B, D, T] + encoded = encoded.transpose(1, 2) + + # Ensure that shape mismatch does not occur due to padding + # Due to padding and subsequent downsampling, it may be possible that + # max sequence length computed does not match the actual max sequence length + max_output_len = encoded_len.max() + if encoded.shape[2] != max_output_len: + encoded = encoded.narrow(dim=2, start=0, length=max_output_len).contiguous() + + logits, encoded_len = self.decoder(encoder_output=encoded, lens=encoded_len, log_prob=False) + log_probs = torch.nn.functional.log_softmax(logits, dim=-1) + + with torch.no_grad(): + greedy_predictions = log_probs.argmax(dim=-1, keepdim=False) + + return log_probs, encoded_len, greedy_predictions, logits + + def training_step(self, batch, batch_nb): + signal, signal_len, transcript, transcript_len = batch + log_probs, encoded_len, predictions, _ = self(input_signal=signal, input_signal_length=signal_len, + global_step=self.trainer.global_step) + loss_value = self.loss( + log_probs=log_probs, targets=transcript, input_lengths=encoded_len, target_lengths=transcript_len + ) + tensorboard_logs = {'train_loss': loss_value, 'learning_rate': self._optimizer.param_groups[0]['lr']} + return {'loss': loss_value, 'log': tensorboard_logs} + + def validation_step(self, batch, batch_idx, dataloader_idx=0, decode_results=None): + signal, signal_len, transcript, transcript_len = batch + log_probs, encoded_len, predictions, logits = self(input_signal=signal, input_signal_length=signal_len, + global_step=None) + loss_value = self.loss( + log_probs=log_probs, targets=transcript, input_lengths=encoded_len, target_lengths=transcript_len + ) + self._wer.update( + predictions=predictions, targets=transcript, target_lengths=transcript_len, predictions_lengths=encoded_len, + log_prediction=batch_idx < 3, decode_results=decode_results) + wer, wer_num, wer_denom = self._wer.compute() + return { + 'val_loss': loss_value, + 'val_wer_num': wer_num, + 'val_wer_denom': wer_denom, + 'val_wer': wer, + 'val_logprob': log_probs.cpu().numpy(), + 'val_logprob_len': encoded_len.cpu().numpy(), + 'val_logits': logits.cpu().numpy(), + } + + def test_step(self, batch, batch_idx, dataloader_idx=0): + decode_results = {} + logs = self.validation_step(batch, batch_idx, dataloader_idx=dataloader_idx, decode_results=decode_results) + test_logs = { + 'test_loss': logs['val_loss'], + 'test_wer_num': logs['val_wer_num'], + 'test_wer_denom': logs['val_wer_denom'], + 'test_wer': logs['val_wer'], + 'test_references': decode_results['references'], + 'test_hypotheses': decode_results['hypotheses'], + 'test_logprob': logs['val_logprob'], + 'test_logprob_len': logs['val_logprob_len'], + 'test_logits': logs['val_logits'], + } + return test_logs + + def test_dataloader(self): + if self._test_dl is not None: + return self._test_dl + + def _setup_transcribe_dataloader(self, config: Dict) -> 'torch.utils.data.DataLoader': + """ + Setup function for a temporary data loader which wraps the provided audio file. + + Args: + config: A python dictionary which contains the following keys: + paths2audio_files: (a list) of paths to audio files. The files should be relatively short fragments. \ + Recommended length per file is between 5 and 25 seconds. + batch_size: (int) batch size to use during inference. \ + Bigger will result in better throughput performance but would use more memory. + temp_dir: (str) A temporary directory where the audio manifest is temporarily + stored. + + Returns: + A pytorch DataLoader for the given audio file(s). + """ + assert not self.use_bpe + dl_config = { + 'manifest_filepath': os.path.join(config['temp_dir'], 'manifest.json'), + 'sample_rate': self.preprocessor._sample_rate, + 'labels': self.decoder.vocabulary, + 'batch_size': min(config['batch_size'], len(config['paths2audio_files'])), + 'trim_silence': True, + 'shuffle': False, + } + + temporary_datalayer = self._setup_dataloader_from_config(config=DictConfig(dl_config), + noise_perturb_config=None) + return temporary_datalayer + + @classmethod + def init_encoder_from_pretrain_model( + cls, + encoder, + encoder_param_prefix, + checkpoint_path, + *, + map_location=None, + strict: bool = True): + try: + cls._set_model_restore_state(is_being_restored=True) + + from pytorch_lightning.utilities.cloud_io import load as pl_load + if map_location is not None: + checkpoint = pl_load(checkpoint_path, map_location=map_location) + else: + checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage) + + # for past checkpoint need to add the new key + assert cls.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint + pretrain_cfg = checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY] + + # give model a chance to load something + # model.on_load_checkpoint(checkpoint) + + # load the state_dict on the model automatically + if encoder_param_prefix is not None: + encoder_state = {k[len(encoder_param_prefix):]: v for k, v in checkpoint['state_dict'].items() if + k.startswith(encoder_param_prefix)} + else: + encoder_state = checkpoint['state_dict'] + encoder.load_state_dict(encoder_state, strict=strict) + finally: + cls._set_model_restore_state(is_being_restored=False) + + def multi_test_epoch_end(self, outputs, dataloader_idx: int = 0): + val_loss_mean = torch.stack([x['test_loss'] for x in outputs]).mean() + wer_num = torch.stack([x['test_wer_num'] for x in outputs]).sum() + wer_denom = torch.stack([x['test_wer_denom'] for x in outputs]).sum() + tensorboard_logs = {'test_loss': val_loss_mean, 'test_wer': wer_num / wer_denom} + references = itertools.chain.from_iterable([x['test_references'] for x in outputs]) + hypotheses = itertools.chain.from_iterable([x['test_hypotheses'] for x in outputs]) + test_logprob = [x['test_logprob'] for x in outputs] + test_logprob_len = [x['test_logprob_len'] for x in outputs] + test_logits = [x['test_logits'] for x in outputs] + return {'test_loss': val_loss_mean, 'log': tensorboard_logs, 'decode_results': (references, hypotheses), + 'test_logprob': test_logprob, 'test_logprob_len': test_logprob_len, 'test_logits': test_logits} diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/spec2vec/ctc_finetune_model.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/spec2vec/ctc_finetune_model.py new file mode 100644 index 0000000000000000000000000000000000000000..8e527880975297c7c54adc33d52ccf20af94d6aa --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/spec2vec/ctc_finetune_model.py @@ -0,0 +1,40 @@ +import contextlib +import torch + +from nemo.core import Serialization + + +class CTCFinetuneModel(torch.nn.Module): + def __init__(self, encoder, cfg): + super().__init__() + + self.encoder = encoder + + self.decoder = Serialization.from_config_dict(cfg.decoder) + + self.freeze_finetune_updates = cfg.freeze_finetune_updates + + def forward(self, input_signal, input_signal_length, global_step): + ft = False if global_step is None else self.freeze_finetune_updates <= global_step + with torch.no_grad() if not ft else contextlib.suppress(): + encoded, encoded_len = self.encoder(input_signal, input_signal_length, None, None, + mask=self.training, features_only=True) + + # [B, T, D] => [B, D, T] + encoded = encoded.transpose(1, 2) + + # Ensure that shape mismatch does not occur due to padding + # Due to padding and subsequent downsampling, it may be possible that + # max sequence length computed does not match the actual max sequence length + max_output_len = encoded_len.max() + if encoded.shape[2] != max_output_len: + encoded = encoded.narrow(dim=2, start=0, length=max_output_len).contiguous() + + logits, encoded_len = self.decoder(encoder_output=encoded, lens=encoded_len, log_prob=False) + log_probs = torch.nn.functional.log_softmax(logits, dim=-1) + + # with torch.no_grad(): + # greedy_predictions = log_probs.argmax(dim=-1, keepdim=False) + + return log_probs, encoded_len, logits + diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/spec2vec/spec2vec_config.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/spec2vec/spec2vec_config.py new file mode 100644 index 0000000000000000000000000000000000000000..b540ced56efa52b05388a696ed04161032ecb239 --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/spec2vec/spec2vec_config.py @@ -0,0 +1,144 @@ +from typing import Optional, List, Any + +from dataclasses import field, dataclass +from omegaconf import MISSING + +from nemo.collections.asr.models.configs.common_config import Conv2dBlock, Conv1dNormAct, DatasetConfig, Tokenizer +from nemo.collections.asr.models.wav2vec.wav2vec_config import LossConfig, QuantizerConfig, Wav2VecTransformerConfig, \ + Wav2VecMaskingConfig +from nemo.collections.asr.modules.audio_preprocessing import AudioToMelSpectrogramPreprocessorConfig +from nemo.core.config.modelPT import ModelConfig + + +@dataclass +class ConvTransformerBlock: + conv_layers: List[Conv1dNormAct] = MISSING + transformer_block: Optional[Wav2VecTransformerConfig] = None + + +@dataclass +class FeatureEncoderConfig: + _target_: str = 'nemo.collections.asr.parts.spec2vec.FeatureEncoder' + feat_in: int = MISSING + use_conv_mask: bool = MISSING + conv2d_block: Optional[Conv2dBlock] = MISSING + conv_transformer_blocks: List[ConvTransformerBlock] = MISSING + use_tf_pad: bool = True + ln_eps: float = 1e-5 + + +@dataclass +class ProjectorConfig: + input_dim: Optional[int] = None + output_dim: Optional[int] = None + use_conv_mask: bool = True + use_tf_pad: bool = True + ln_eps: float = 1e-5 + conv_layers: Optional[List[Conv1dNormAct]] = None + transformer: Optional[Wav2VecTransformerConfig] = None + + +@dataclass +class Spec2VecEncoderConfig: + noisy_spec2vec: bool = False + + preprocessor: AudioToMelSpectrogramPreprocessorConfig = MISSING + + quantizer: QuantizerConfig = QuantizerConfig() + feature_encoder: FeatureEncoderConfig = FeatureEncoderConfig() + freeze_feature_encoder: bool = False + targets_grad_update_inverval: int = 1 + transformer_encoder: Wav2VecTransformerConfig = Wav2VecTransformerConfig() + masking: Optional[Wav2VecMaskingConfig] = None + learnable_mask: bool = True + + dropout_input: float = field(default=0.1, metadata={'help': 'Dropout applied to input raw features'}) + dropout_features: float = field( + default=0.1, metadata={'help': 'Dropout applied to the features generator by convolutions'} + ) + final_hidden_dim: Optional[int] = None + dropout_final: float = 0.1 + final_dim: int = field(default=0, metadata={'help': 'Project final representations and targets to this dimension'}) + n_negatives: int = field( + default=100, metadata={'help': 'Number of negatives to sample from the same audio sample'} + ) + cross_sample_negatives: int = field( + default=0, metadata={'help': 'Number of negatives to sample from any sample in the batch'} + ) + codebook_negatives: int = field(default=0, metadata={'help': 'Number of negative examples in codebook'}) + negatives_from_everywhere: bool = field( + default=False, metadata={'help': 'Sample negatives from everywhere, not just masked states'} + ) + + +@dataclass +class Spec2VecPretrainModelConfig(ModelConfig): + spec2vec_encoder: Spec2VecEncoderConfig = MISSING + + logit_temp: float = field(default=0.1, metadata={'help': 'Temperature to divide logits by'}) + loss: LossConfig = LossConfig() + + expected_gpu_num: int = 1 + + +@dataclass +class NoisePerturbConfig: + manifest_path: List[str] + min_snr_db: float + max_snr_db: float + max_gain_db: float = 300.0 + ratio: float = 1.0 + target_sr: int = 16000 + data_dir: str = '' + cache_noise: bool = False + + +@dataclass +class Spec2VecCTCFinetuneModelConfig(ModelConfig): + pretrain_chkpt_path: Optional[str] = MISSING + + encoder_type: str = 'spec2vec' + encoder: Any = MISSING + decoder: Any = MISSING + + labels: Optional[List[str]] = None + tokenizer: Optional[Tokenizer] = None + add_end_space: bool = False + lang: str = 'en' + + freeze_finetune_updates: int = 0 + + noise_perturb: Optional[NoisePerturbConfig] = None + + # Dataset configs + train_ds: DatasetConfig = MISSING + validation_ds: DatasetConfig = MISSING + test_ds: DatasetConfig = MISSING + + expected_gpu_num: int = MISSING + + +@dataclass +class Wav2VecCTCFinetuneModelConfig(Spec2VecCTCFinetuneModelConfig): + encoder_type: str = 'wav2vec' + + +@dataclass +class ST2VecCTCFinetuneModelConfig(Spec2VecCTCFinetuneModelConfig): + label_type: Optional[str] = None # ['char', 'phone','unit', 'bpe'] + encoder_type: str = 'st' + use_teacher_encoder: bool = False + + +@dataclass +class ST2VecVQCTCFinetuneModelConfig(Spec2VecCTCFinetuneModelConfig): + label_type: str = None # ['char', 'phone','bpe'] + encoder_type: str = 'st' + quantizer: QuantizerConfig = MISSING + quant_ppl_loss_weight: float = MISSING + use_teacher_encoder: bool = False + + +@dataclass +class FeatST2VecCTCFinetuneModelConfig(Spec2VecCTCFinetuneModelConfig): + encoder_type: str = 'feat_st' diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/spec2vec/spec2vec_model.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/spec2vec/spec2vec_model.py new file mode 100644 index 0000000000000000000000000000000000000000..00a3fdefa845575ebb5739e89bc364541058fe94 --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/spec2vec/spec2vec_model.py @@ -0,0 +1,498 @@ +import contextlib + +import numpy as np +import torch +from omegaconf import DictConfig +from torch import nn + +from nemo.collections.asr.modules.wav2vec_modules import GumbelVectorQuantizer, compute_mask_indices +from nemo.collections.asr.parts.wav2vec import Wav2VecTransformerEncoder, TransformerEncoder +from nemo.core.classes.common import Serialization + + +def buffered_arange(max): + if not hasattr(buffered_arange, "buf"): + buffered_arange.buf = torch.LongTensor() + if max > buffered_arange.buf.numel(): + buffered_arange.buf.resize_(max) + torch.arange(max, out=buffered_arange.buf) + return buffered_arange.buf[:max] + + +class Spec2VecEncoder(nn.Module): + def __init__(self, cfg: DictConfig): + super().__init__() + + self.noisy_spec2vec = cfg.noisy_spec2vec + self.wav2spec = Serialization.from_config_dict(cfg.preprocessor) + self.feature_encoder = Serialization.from_config_dict(cfg.feature_encoder) + self.freeze_feature_encoder = cfg.freeze_feature_encoder + + self.mask_cfg = cfg.masking + + self.dropout_input = nn.Dropout(cfg.dropout_input) + self.dropout_features = nn.Dropout(cfg.dropout_features) + + self.quantizer = None + + self.n_negatives = cfg.n_negatives + self.cross_sample_negatives = cfg.cross_sample_negatives + self.codebook_negatives = cfg.codebook_negatives + self.negatives_from_everywhere = cfg.negatives_from_everywhere + + self.final_dim = cfg.final_dim + assert self.final_dim > 0 + self.quantize_targets = cfg.quantizer.quantize_targets + if self.quantize_targets: + assert cfg.quantizer.targets_bottleneck_dim is None + vq_dim = cfg.quantizer.latent_dim if cfg.quantizer.latent_dim > 0 else self.final_dim + self.quantizer = GumbelVectorQuantizer( + dim=self.feature_encoder.output_dim, + num_vars=cfg.quantizer.latent_vars, + temp=cfg.quantizer.latent_temp, + groups=cfg.quantizer.latent_groups, + combine_groups=False, + vq_dim=vq_dim, + time_first=True, + ) + self.project_q = nn.Linear(vq_dim, self.final_dim) + else: + targets_bottleneck_dim = cfg.quantizer.targets_bottleneck_dim + if targets_bottleneck_dim is None: + self.project_q = nn.Linear(self.feature_encoder.output_dim, self.final_dim) + else: + act_fn_dic = {'relu': nn.ReLU, 'gelu': nn.GELU} + targets_proj_act_fn = cfg.quantizer.targets_bottleneck_act_fn + targets_proj_layers = ( + [nn.Linear(self.feature_encoder.output_dim, targets_bottleneck_dim)] + + ([] if targets_proj_act_fn is None else [act_fn_dic[targets_proj_act_fn]()]) + + [nn.Dropout(cfg.quantizer.targets_bottleneck_dropout)] + + [nn.Linear(targets_bottleneck_dim, self.final_dim)] + + ) + self.project_q = torch.nn.Sequential(*targets_proj_layers) + + self.targets_grad_update_inverval = cfg.targets_grad_update_inverval + self.grad_step_count = 0 + + encoder_embed_dim = cfg.transformer_encoder.encoder.embedding_dim + if cfg.learnable_mask: + if self.noisy_spec2vec: + mask_emb_dim = cfg.preprocessor.features + else: + mask_emb_dim = encoder_embed_dim + self.mask_emb = nn.Parameter(torch.FloatTensor(mask_emb_dim).uniform_()) + else: + self.mask_emb = 0.0 + + if cfg.transformer_encoder.use_pytorch_transformer: + self.encoder = Wav2VecTransformerEncoder(cfg.transformer_encoder) + else: + self.encoder = TransformerEncoder(cfg.transformer_encoder) + + if cfg.final_hidden_dim is None: + self.final_proj = nn.Linear(encoder_embed_dim, self.final_dim) + else: + self.final_proj = torch.nn.Sequential( + nn.Linear(encoder_embed_dim, cfg.final_hidden_dim), + nn.Dropout(cfg.dropout_final), + nn.ReLU(), + nn.Linear(cfg.final_hidden_dim, self.final_dim) + ) + + def forward(self, wavs, wav_lens, *, mask=True, features_only=False) -> tuple: + if self.noisy_spec2vec: + return self.noisy_spec2vec_forward(wavs, wav_lens, mask=mask, features_only=features_only) + else: + return self.spec2vec_forward(wavs, wav_lens, mask=mask, features_only=features_only) + + def spec2vec_forward(self, wavs, wav_lens, *, mask=True, features_only=False) -> tuple: + specs, specs_len = self.wav2spec( + input_signal=wavs, length=wav_lens, + ) + + if self.freeze_feature_encoder: + self.feature_encoder.bn_eval() + with torch.no_grad() if self.freeze_feature_encoder else contextlib.suppress(): + features, feature_lens, _ = self.feature_encoder(specs, specs_len) + # [B, D, T] => [B, T, D] + features = features.transpose(1, 2) + + unmasked_features = None if features_only else features.clone() + + padding_mask = self._create_padding_mask(feature_lens, features.shape[1]) + assert padding_mask.size(1) == features.size(1) + assert padding_mask.ndim == 2 + + features = self.dropout_input(features) + if not features_only: + unmasked_features = self.dropout_features(unmasked_features) + + if mask and self.mask_cfg is not None: + logits, mask_indices, mask_num = self.apply_mask(features, padding_mask) + if features_only: + targets = None + elif mask_indices is not None: + targets = unmasked_features[mask_indices] + if self.mask_cfg.mask_shrink_to_batch_min: + targets = targets.view( + unmasked_features.size(0), -1, unmasked_features.size(-1) + ) + else: + # fake batch dim 1 + targets = targets.view( + 1, -1, unmasked_features.size(-1) + ) + assert targets.shape[1] == sum(mask_num) + else: + targets = unmasked_features + else: + logits = features + targets = None if features_only else unmasked_features + mask_indices = None + mask_num = None + + logits = self.encoder(logits, padding_mask=padding_mask) + + if features_only: + return logits, feature_lens + + prob_ppl_loss, cur_temp = None, None + if self.quantize_targets: + targets, prob_ppl_loss, cur_temp, prob_ppl = self.quantizer(targets) + targets = self.project_q(targets) + + if self.negatives_from_everywhere: + assert self.mask_cfg.mask_shrink_to_batch_min + neg_cands, *_ = self.quantizer(unmasked_features) + sampled_negatives, _ = self.sample_negatives(neg_cands, targets.size(1)) + sampled_negatives = self.project_q(sampled_negatives) + else: + if self.mask_cfg.mask_shrink_to_batch_min: + sampled_negatives, _ = self.sample_negatives(targets, targets.size(1)) + else: + sampled_negatives, _ = self.sample_negatives_flat(targets, mask_num) + + if self.codebook_negatives > 0: + assert self.mask_cfg.mask_shrink_to_batch_min + cb_negs = self.quantizer.sample_from_codebook( + targets.size(0) * targets.size(1), self.codebook_negatives + ) + cb_negs = cb_negs.view( + self.codebook_negatives, targets.size(0), targets.size(1), -1 + ) # order doesnt matter + cb_negs = self.project_q(cb_negs) + sampled_negatives = torch.cat([sampled_negatives, cb_negs], dim=0) + else: + targets = self.project_q(targets) + prob_ppl = None + + if self.negatives_from_everywhere: + assert self.mask_cfg.mask_shrink_to_batch_min + sampled_negatives, _ = self.sample_negatives(unmasked_features, targets.size(1)) + sampled_negatives = self.project_q(sampled_negatives) + else: + if self.mask_cfg.mask_shrink_to_batch_min: + sampled_negatives, _ = self.sample_negatives(targets, targets.size(1)) + else: + sampled_negatives, _ = self.sample_negatives_flat(targets, mask_num) + + mask_logits = logits[mask_indices] + if self.mask_cfg.mask_shrink_to_batch_min: + mask_logits = mask_logits.view(logits.size(0), -1, logits.size(-1)) + else: + # fake batch dim to 1 + mask_logits = mask_logits.view(1, -1, logits.size(-1)) + + mask_logits = self.final_proj(mask_logits) + + return mask_logits, targets, sampled_negatives, padding_mask, prob_ppl_loss, cur_temp, prob_ppl + + def noisy_spec2vec_forward(self, wavs, wav_lens, *, mask=True, features_only=False) -> tuple: + specs, specs_len = self.wav2spec( + input_signal=wavs, length=wav_lens, + ) + + unmasked_specs = None if features_only else specs.clone() + + if mask: + specs = specs.transpose(1, 2) + specs_mask = self._create_padding_mask(specs_len, specs.shape[1]) + mask_positions = [] + specs, _, _ = self.apply_mask(specs, specs_mask, mask_positions=mask_positions) + specs = specs.transpose(1, 2) + else: + mask_positions = None + + if self.freeze_feature_encoder: + self.feature_encoder.bn_eval() + with torch.no_grad() if self.freeze_feature_encoder else contextlib.suppress(): + features, feature_lens, _ = self.feature_encoder(specs, specs_len) + # [B, D, T] => [B, T, D] + features = features.transpose(1, 2) + + if features_only: + unmasked_features = None + else: + no_targets_grad = False + if self.training: + self.grad_step_count += 1 + if self.targets_grad_update_inverval == 0 or (self.grad_step_count % self.targets_grad_update_inverval != 0): + no_targets_grad = True + if self.targets_grad_update_inverval == 1: + assert not no_targets_grad + + with torch.no_grad() if no_targets_grad else contextlib.suppress(): + with as_eval(self.feature_encoder): + unmasked_features, _, _ = self.feature_encoder(unmasked_specs, specs_len) + unmasked_features = unmasked_features.transpose(1, 2) + + padding_mask = self._create_padding_mask(feature_lens, features.shape[1]) + assert padding_mask.size(1) == features.size(1) + assert padding_mask.ndim == 2 + + features = self.dropout_input(features) + if not features_only: + unmasked_features = self.dropout_features(unmasked_features) + + if mask and not features_only: + # positions to lens + mask_positions = np.array(mask_positions) + 1 + mask_positions = self.feature_encoder.get_subsampled_lens(mask_positions) + # lens to positions + mask_positions = mask_positions - 1 + mask_indices = np.full((unmasked_features.shape[0], unmasked_features.shape[1]), False) + mask_num = [] + for i, mask_position_i in enumerate(mask_positions): + mask_position_i = np.unique(mask_position_i) + mask_num.append(mask_position_i.shape[0]) + mask_indices[i, mask_position_i] = True + mask_indices = torch.from_numpy(mask_indices).to(unmasked_features.device) + targets = unmasked_features[mask_indices] + if self.mask_cfg.mask_shrink_to_batch_min: + targets = targets.view( + unmasked_features.size(0), -1, unmasked_features.size(-1) + ) + else: + # fake batch dim 1 + targets = targets.view( + 1, -1, unmasked_features.size(-1) + ) + else: + mask_indices = None + mask_num = None + targets = None + + logits = self.encoder(features, padding_mask=padding_mask) + + if features_only: + return logits, feature_lens + + prob_ppl_loss, cur_temp = None, None + if self.quantize_targets: + targets, prob_ppl_loss, cur_temp, prob_ppl = self.quantizer(targets) + targets = self.project_q(targets) + + if self.negatives_from_everywhere: + assert self.mask_cfg.mask_shrink_to_batch_min + neg_cands, *_ = self.quantizer(unmasked_features) + sampled_negatives, _ = self.sample_negatives(neg_cands, targets.size(1)) + sampled_negatives = self.project_q(sampled_negatives) + else: + if self.mask_cfg.mask_shrink_to_batch_min: + sampled_negatives, _ = self.sample_negatives(targets, targets.size(1)) + else: + sampled_negatives, _ = self.sample_negatives_flat(targets, mask_num) + + if self.codebook_negatives > 0: + assert self.mask_cfg.mask_shrink_to_batch_min + cb_negs = self.quantizer.sample_from_codebook( + targets.size(0) * targets.size(1), self.codebook_negatives + ) + cb_negs = cb_negs.view( + self.codebook_negatives, targets.size(0), targets.size(1), -1 + ) # order doesnt matter + cb_negs = self.project_q(cb_negs) + sampled_negatives = torch.cat([sampled_negatives, cb_negs], dim=0) + else: + targets = self.project_q(targets) + prob_ppl = None + + if self.negatives_from_everywhere: + assert self.mask_cfg.mask_shrink_to_batch_min + sampled_negatives, _ = self.sample_negatives(unmasked_features, targets.size(1)) + sampled_negatives = self.project_q(sampled_negatives) + else: + if self.mask_cfg.mask_shrink_to_batch_min: + sampled_negatives, _ = self.sample_negatives(targets, targets.size(1)) + else: + sampled_negatives, _ = self.sample_negatives_flat(targets, mask_num) + + mask_logits = logits[mask_indices] + if self.mask_cfg.mask_shrink_to_batch_min: + mask_logits = mask_logits.view(logits.size(0), -1, logits.size(-1)) + else: + # fake batch dim to 1 + mask_logits = mask_logits.view(1, -1, logits.size(-1)) + + mask_logits = self.final_proj(mask_logits) + + return mask_logits, targets, sampled_negatives, padding_mask, prob_ppl_loss, cur_temp, prob_ppl + + def extract_features(self, source, audio_lengths, mask=False): + padding_mask = self._create_padding_mask(audio_lengths, max_len=source.shape[1]) + return self(source=source, padding_mask=padding_mask, mask=mask, features_only=True) + + def remove_pretraining_modules(self): + self.quantizer = None + self.project_q = None + self.final_proj = None + self.dropout_features = None + + def _update_quantizer_temp(self, global_step): + if self.quantize_targets: + self.quantizer.set_num_updates(global_step) + + def apply_mask(self, x, padding_mask, mask_positions=None): + B, T, C = x.shape + if self.mask_cfg.mask_prob > 0: + mask_indices, mask_num = compute_mask_indices( + (B, T), + padding_mask, + self.mask_cfg.mask_prob, + self.mask_cfg.mask_length, + self.mask_cfg.mask_type, + self.mask_cfg.mask_other, + min_masks=2, + no_overlap=self.mask_cfg.no_mask_overlap, + min_space=self.mask_cfg.mask_min_space, + shrink_to_batch_min=self.mask_cfg.mask_shrink_to_batch_min, + mask_positions=mask_positions + ) + mask_indices = torch.from_numpy(mask_indices).to(x.device) + mask_emb = self.mask_emb + if isinstance(mask_emb, torch.Tensor): + mask_emb = mask_emb.type_as(x) + x[mask_indices] = mask_emb + else: + mask_indices = None + + if self.mask_cfg.mask_channel_prob > 0: + # assert self.mask_cfg.mask_shrink_to_batch_min + mask_channel_indices, _ = compute_mask_indices( + (B, C), + None, + self.mask_cfg.mask_channel_prob, + self.mask_cfg.mask_channel_length, + self.mask_cfg.mask_channel_type, + self.mask_cfg.mask_channel_other, + no_overlap=self.mask_cfg.no_mask_channel_overlap, + min_space=self.mask_cfg.mask_channel_min_space, + shrink_to_batch_min=self.mask_cfg.mask_shrink_to_batch_min, + ) + mask_channel_indices = torch.from_numpy(mask_channel_indices).to(x.device).unsqueeze(1).expand(-1, T, -1) + x[mask_channel_indices] = 0 + + assert len(mask_num) == B + return x, mask_indices, mask_num + + def sample_negatives(self, y, num): + + if self.n_negatives == 0 and self.cross_sample_negatives == 0: + return y.new(0) + + bsz, tsz, fsz = y.shape + y = y.view(-1, fsz) # BTC => (BxT)C + + cross_high = tsz * bsz + high = tsz + with torch.no_grad(): + assert high > 1, f"{bsz, tsz, fsz}" + + if self.n_negatives > 0: + tszs = buffered_arange(num).unsqueeze(-1).expand(-1, self.n_negatives).flatten() + + neg_idxs = torch.randint(low=0, high=high - 1, size=(bsz, self.n_negatives * num)) + neg_idxs[neg_idxs >= tszs] += 1 + + if self.cross_sample_negatives > 0: + tszs = buffered_arange(num).unsqueeze(-1).expand(-1, self.cross_sample_negatives).flatten() + + cross_neg_idxs = torch.randint( + low=0, high=cross_high - 1, size=(bsz, self.cross_sample_negatives * num), + ) + cross_neg_idxs[cross_neg_idxs >= tszs] += 1 + + if self.n_negatives > 0: + for i in range(1, bsz): + neg_idxs[i] += i * high + else: + neg_idxs = cross_neg_idxs + + if self.cross_sample_negatives > 0 and self.n_negatives > 0: + neg_idxs = torch.cat([neg_idxs, cross_neg_idxs], dim=1) + + negs = y[neg_idxs.view(-1)] + negs = negs.view(bsz, num, self.n_negatives + self.cross_sample_negatives, fsz).permute( + 2, 0, 1, 3 + ) # to NxBxTxC + return negs, neg_idxs + + def sample_negatives_flat(self, y, nums): + + if self.n_negatives == 0 and self.cross_sample_negatives == 0: + return y.new(0) + + bsz, tsz, fsz = y.shape + assert bsz == 1 and tsz == sum(nums) # fake batch dim + y = y.view(-1, fsz) # BTC => (BxT)C + + # cross_high = tsz * bsz + + neg_idxs_l = [] + idx_start = 0 + with torch.no_grad(): + for i, num_i in enumerate(nums): + assert num_i > 1, f"{bsz, tsz, fsz}" + + assert self.n_negatives > 0 + tszs_i = buffered_arange(num_i).unsqueeze(-1).expand(-1, self.n_negatives).flatten() + + high_i = num_i + neg_idxs_i = torch.randint(low=0, high=high_i - 1, size=(self.n_negatives * num_i,)) + neg_idxs_i[neg_idxs_i >= tszs_i] += 1 + + neg_idxs_i += idx_start + idx_start += num_i + + neg_idxs_l.append(neg_idxs_i) + + assert self.cross_sample_negatives == 0 + + neg_idxs = torch.cat(neg_idxs_l) + assert neg_idxs.ndim == 1 + + negs = y[neg_idxs] + negs = negs.view(bsz, sum(nums), self.n_negatives + self.cross_sample_negatives, fsz).permute( + 2, 0, 1, 3 + ) # to NxBxTxC + return negs, neg_idxs + + def _create_padding_mask(self, audio_lengths, max_len): + # Broadcast to vectorize creating the padding mask + padding_mask = torch.arange(max_len, device=audio_lengths.device) + padding_mask = padding_mask.expand(len(audio_lengths), max_len) < audio_lengths.unsqueeze(1) + # Negate to false where no padding + padding_mask = ~padding_mask + return padding_mask + + +@contextlib.contextmanager +def as_eval(module): + training_state = module.training + module.eval() + + try: + yield + finally: + module.train(training_state) diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/spec2vec/spec2vec_pretrain.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/spec2vec/spec2vec_pretrain.py new file mode 100644 index 0000000000000000000000000000000000000000..08c4d7ffc5544722e3fb2649a46c0666fa7f06bd --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/spec2vec/spec2vec_pretrain.py @@ -0,0 +1,166 @@ +import logging +from math import ceil +from typing import Dict, Optional, Union + +import torch +from omegaconf import DictConfig, OmegaConf +from pytorch_lightning import Trainer + +from nemo.collections.asr.data import audio_to_text_dataset +from nemo.collections.asr.losses.wav2vecloss import Wav2VecLoss +from nemo.collections.asr.models.spec2vec.spec2vec_model import Spec2VecEncoder +from nemo.collections.asr.parts.perturb import process_augmentations +from nemo.core import ModelPT +from nemo.core.classes.common import PretrainedModelInfo + + +def buffered_arange(max): + if not hasattr(buffered_arange, "buf"): + buffered_arange.buf = torch.LongTensor() + if max > buffered_arange.buf.numel(): + buffered_arange.buf.resize_(max) + torch.arange(max, out=buffered_arange.buf) + return buffered_arange.buf[:max] + + +class Spec2VecPretrainModel(ModelPT): + def __init__(self, cfg: DictConfig, trainer: Trainer = None): + # Get global rank and total number of GPU workers for IterableDataset partitioning, if applicable + self.global_rank = 0 + self.world_size = 1 + self.local_rank = 0 + if trainer is not None: + self.global_rank = (trainer.node_rank * trainer.num_gpus) + trainer.local_rank + self.world_size = trainer.num_nodes * trainer.num_gpus + self.local_rank = trainer.local_rank + + super().__init__(cfg=cfg, trainer=trainer) + + if isinstance(cfg, dict): + cfg = OmegaConf.create(cfg) + elif not isinstance(cfg, DictConfig): + raise ValueError(f"cfg was type: {type(cfg)}. Expected either a dict or a DictConfig") + + self.spec2vec_encoder = Spec2VecEncoder(cfg.spec2vec_encoder) + + self.loss = Wav2VecLoss( + feature_loss_weight=0.0, + prob_ppl_weight=cfg.loss.prob_ppl_weight if self.spec2vec_encoder.quantize_targets else 0.0, + logit_temp=cfg.logit_temp, + ) + + self._prev_log_step = -1 + + def training_step(self, batch, batch_idx): + loss, contrastive_loss, prob_ppl_loss, cur_temp, prob_ppl, _ = self._step(batch) + + if self.global_step > self._prev_log_step: + self._prev_log_step = self.global_step + tensorboard = self.logger.experiment + tensorboard.add_scalar('loss', loss, self.global_step) + tensorboard.add_scalar('contrastive_loss', contrastive_loss, self.global_step) + if prob_ppl_loss is not None: + tensorboard.add_scalar('prob_ppl_loss', prob_ppl_loss, self.global_step) + tensorboard.add_scalar('temp', cur_temp, self.global_step) + tensorboard.add_scalar('prob_ppl', prob_ppl, self.global_step) + tensorboard.add_scalar('learning_rate', self._optimizer.param_groups[0]['lr'], self.global_step) + return {'loss': loss} + + def validation_step(self, batch, batch_idx, dataloader_idx=0): + loss, contrastive_loss, prob_ppl_loss, _, _, accuracy = self._step(batch) + self.log('val_loss', loss, prog_bar=True, on_epoch=True, sync_dist=True) + self.log('val_accuracy', accuracy, prog_bar=True, on_step=False, on_epoch=True, sync_dist=False) + + def test_step(self, batch, batch_idx, dataloader_idx=0): + loss, contrastive_loss, prob_ppl_loss, _, _, accuracy = self._step(batch) + self.log('test_loss', loss, prog_bar=True, on_epoch=True, sync_dist=True) + self.log('test_accuracy', accuracy, prog_bar=True, on_step=False, on_epoch=True, sync_dist=False) + + def _step(self, batch): + audio_signal, audio_lengths = batch + + self.spec2vec_encoder._update_quantizer_temp(self.trainer.global_step) + logits, targets, sampled_negatives, _, prob_ppl_loss, cur_temp, prob_ppl = self( + source=audio_signal, source_lens=audio_lengths + ) + loss, contrastive_loss, _, prob_ppl_loss, accuracy = self.loss( + logits=logits, + targets=targets, + negatives=sampled_negatives, + prob_ppl_loss=prob_ppl_loss, + feature_loss=None, + compute_accuracy=not self.training + ) + return loss, contrastive_loss, prob_ppl_loss, cur_temp, prob_ppl, accuracy + + @classmethod + def list_available_models(cls) -> Optional[PretrainedModelInfo]: + return None + + def forward(self, source, source_lens, mask=True, features_only=False) -> tuple: + return self.spec2vec_encoder(source, source_lens, mask=mask, features_only=features_only) + + def setup_training_data(self, train_data_config: Optional[Union[DictConfig, Dict]]): + if 'shuffle' not in train_data_config: + train_data_config['shuffle'] = True + + # preserve config + self._update_dataset_config(dataset_name='train', config=train_data_config) + + self._train_dl = self._setup_dataloader_from_config(config=train_data_config) + + # Need to set this because if using an IterableDataset, the length of the dataloader is the total number + # of samples rather than the number of batches, and this messes up the tqdm progress bar. + # So we set the number of steps manually (to the correct number) to fix this. + if 'is_tarred' in train_data_config and train_data_config['is_tarred']: + # We also need to check if limit_train_batches is already set. + # If it's an int, we assume that the user has set it to something sane, i.e. <= # training batches, + # and don't change it. Otherwise, adjust batches accordingly if it's a float (including 1.0). + if isinstance(self._trainer.limit_train_batches, float): + self._trainer.limit_train_batches = int( + self._trainer.limit_train_batches + * ceil((len(self._train_dl.dataset) / self.world_size) / train_data_config['batch_size']) + ) + + def setup_validation_data(self, val_data_config: Optional[Union[DictConfig, Dict]]): + if 'shuffle' not in val_data_config: + val_data_config['shuffle'] = False + + # preserve config + self._update_dataset_config(dataset_name='validation', config=val_data_config) + + self._validation_dl = self._setup_dataloader_from_config(config=val_data_config) + + def setup_test_data(self, test_data_config: Optional[Union[DictConfig, Dict]]): + if 'shuffle' not in test_data_config: + test_data_config['shuffle'] = False + + # preserve config + self._update_dataset_config(dataset_name='test', config=test_data_config) + + self._test_dl = self._setup_dataloader_from_config(config=test_data_config) + + def _setup_dataloader_from_config(self, config: Optional[Dict]): + + if 'augmentor' in config: + augmentor = process_augmentations(config['augmentor']) + else: + augmentor = None + + shuffle = config['shuffle'] + + if 'manifest_filepath' in config and config['manifest_filepath'] is None: + logging.warning(f"Could not load dataset as `manifest_filepath` was None. Provided config : {config}") + return None + + dataset = audio_to_text_dataset.get_audio_dataset(config=config, augmentor=augmentor) + + return torch.utils.data.DataLoader( + dataset=dataset, + batch_size=config['batch_size'], + collate_fn=dataset.collate_fn, + drop_last=config.get('drop_last', False), + shuffle=shuffle, + num_workers=config.get('num_workers', 0), + pin_memory=config.get('pin_memory', False), + ) diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/spec2vec/unit_ctc_finetune.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/spec2vec/unit_ctc_finetune.py new file mode 100644 index 0000000000000000000000000000000000000000..f3831475a0cf6495b635ffad3ea2c67f5b1d9e7f --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/spec2vec/unit_ctc_finetune.py @@ -0,0 +1,710 @@ +"""Modified from .vq_ctc_finetune.py and .ctc_finetune.py""" +import contextlib +import copy +import itertools +import json +import os +import tempfile +from math import ceil +from typing import Dict, List, Optional, Union + +import torch +from omegaconf import DictConfig, OmegaConf, open_dict, ListConfig +from pytorch_lightning import Trainer +from tqdm.auto import tqdm + +from nemo.collections.asr.data import audio_to_text_dataset +from nemo.collections.asr.losses.ctc import CTCLoss +from nemo.collections.asr.metrics.wer import WER, WER_phone +from nemo.collections.asr.metrics.wer_bpe import WERBPE +from nemo.collections.asr.models.asr_model import ASRModel +from nemo.collections.asr.parts.perturb import process_augmentations, RandomNoisePerturbation, AudioAugmentor +from nemo.utils import logging + + +class UnitCTCFinetuneModel(ASRModel): + """Todo: should modify to remove the vector quantization""" + + def __init__(self, cfg: DictConfig, trainer: Trainer = None): + # Get global rank and total number of GPU workers for IterableDataset partitioning, if applicable + # Global_rank and local_rank is set by LightningModule in Lightning 1.2.0 + self.global_rank = 0 + self.world_size = 1 + self.local_rank = 0 + if trainer is not None: + self.global_rank = (trainer.node_rank * trainer.num_gpus) + trainer.local_rank + self.world_size = trainer.num_nodes * trainer.num_gpus + self.local_rank = trainer.local_rank + self.label_type = cfg.label_type + assert self.label_type in ['char', 'phone', 'unit', 'bpe'] + if self.label_type == 'bpe': + self.use_bpe = True + assert cfg.tokenizer is not None + from nemo.collections.asr.parts.mixins import ASRBPEMixin + self.bpe = ASRBPEMixin() + self.bpe._setup_tokenizer(cfg.tokenizer, register_artifact=False) + self.tokenizer = self.bpe.tokenizer + + # Initialize a dummy vocabulary + vocabulary = self.tokenizer.tokenizer.get_vocab() + + # Set the new vocabulary + assert len(cfg.decoder.vocabulary) == 0 + with open_dict(cfg): + cfg.decoder.vocabulary = ListConfig(list(vocabulary.values())) + else: + self.use_bpe = False + assert cfg.tokenizer is None + self.add_end_space = cfg.add_end_space + self.lang = cfg.lang + + super().__init__(cfg=cfg, trainer=trainer) + + if self._cfg.encoder_type == 'spec2vec': + from nemo.collections.asr.models.spec2vec.spec2vec_model import Spec2VecEncoder + self.encoder = Spec2VecEncoder(self._cfg.encoder) + encoder_param_prefix = 'spec2vec_encoder.' + elif self._cfg.encoder_type == 'st': + from nemo.collections.asr.models.st2vec.st2vec_model import ST2VecEncoder + self.encoder = ST2VecEncoder(self._cfg.encoder) + encoder_param_prefix = 'st2vec_encoder.' + elif self._cfg.encoder_type == 'feat_st': + from nemo.collections.asr.models.st2vec.st2vec_model import FeatST2VecEncoder + self.encoder = FeatST2VecEncoder(self._cfg.encoder) + encoder_param_prefix = 'st2vec_encoder.' + else: + assert self._cfg.encoder_type == 'wav2vec' + from nemo.collections.asr.modules.wav2vec_encoder import Wav2VecEncoderModel + self.encoder = Wav2VecEncoderModel(self._cfg.encoder) + encoder_param_prefix = None + if cfg.pretrain_chkpt_path is not None: + self.init_encoder_from_pretrain_model(self.encoder, encoder_param_prefix, cfg.pretrain_chkpt_path) + if self._cfg.encoder_type == 'st': + self.encoder.remove_pretraining_modules(use_teacher_encoder=self._cfg.use_teacher_encoder) + else: + self.encoder.remove_pretraining_modules() + + self.decoder = self.from_config_dict(self._cfg.decoder) + + self.freeze_finetune_updates = self._cfg.freeze_finetune_updates + + self.loss = CTCLoss( + blank_id=self.decoder.blank_idx, + zero_infinity=True, + reduction=self._cfg.get("ctc_reduction", "mean_batch"), + ) + + # Setup metric objects + if self.use_bpe: + self._wer = WERBPE( + tokenizer=self.tokenizer, + blank_id=self.decoder.blank_idx, + batch_dim_index=0, + use_cer=self._cfg.get('use_cer', False), + ctc_decode=True, + dist_sync_on_step=True, + log_prediction=self._cfg.get("log_prediction", False), + lang=self.lang, + ) + else: + if self.label_type in ['phone', 'unit']: + WER_class = WER_phone + else: + WER_class = WER + self._wer = WER_class( + vocabulary=self.decoder.vocabulary, + blank_id=self.decoder.blank_idx, + batch_dim_index=0, + use_cer=self._cfg.get('use_cer', False), + ctc_decode=True, + dist_sync_on_step=True, + log_prediction=self._cfg.get("log_prediction", False), + strip_end_space=self.add_end_space, + ) + + @torch.no_grad() + def transcribe( + self, paths2audio_files: List[str], batch_size: int = 4, logprobs=False, return_hypotheses: bool = False + ) -> List[str]: + """ + Uses greedy decoding to transcribe audio files. Use this method for debugging and prototyping. + + Args: + paths2audio_files: (a list) of paths to audio files. \ + Recommended length per file is between 5 and 25 seconds. \ + But it is possible to pass a few hours long file if enough GPU memory is available. + batch_size: (int) batch size to use during inference. + Bigger will result in better throughput performance but would use more memory. + logprobs: (bool) pass True to get log probabilities instead of transcripts. + return_hypotheses: (bool) Either return hypotheses or text + With hypotheses can do some postprocessing like getting timestamp or rescoring + + Returns: + A list of transcriptions (or raw log probabilities if logprobs is True) in the same order as paths2audio_files + """ + if paths2audio_files is None or len(paths2audio_files) == 0: + return {} + + if return_hypotheses and logprobs: + raise ValueError( + "Either `return_hypotheses` or `logprobs` can be True at any given time." + "Returned hypotheses will contain the logprobs." + ) + + # We will store transcriptions here + hypotheses = [] + # Model's mode and device + mode = self.training + device = next(self.parameters()).device + dither_value = self.preprocessor.featurizer.dither + pad_to_value = self.preprocessor.featurizer.pad_to + + try: + self.preprocessor.featurizer.dither = 0.0 + self.preprocessor.featurizer.pad_to = 0 + # Switch model to evaluation mode + self.eval() + # Work in tmp directory - will store manifest file there + with tempfile.TemporaryDirectory() as tmpdir: + with open(os.path.join(tmpdir, 'manifest.json'), 'w') as fp: + for audio_file in paths2audio_files: + entry = {'audio_filepath': audio_file, 'duration': 100000, 'text': 'nothing'} + fp.write(json.dumps(entry) + '\n') + + config = {'paths2audio_files': paths2audio_files, 'batch_size': batch_size, 'temp_dir': tmpdir} + + temporary_datalayer = self._setup_transcribe_dataloader(config) + for test_batch in tqdm(temporary_datalayer, desc="Transcribing"): + logits, logits_len, greedy_predictions = self( + input_signal=test_batch[0].to(device), input_signal_length=test_batch[1].to(device), + global_step=None + ) + if logprobs: + # dump log probs per file + for idx in range(logits.shape[0]): + hypotheses.append(logits[idx][: logits_len[idx]]) + else: + current_hypotheses = self._wer.ctc_decoder_predictions_tensor( + greedy_predictions, predictions_len=logits_len, return_hypotheses=return_hypotheses, + ) + + if return_hypotheses: + # dump log probs per file + for idx in range(logits.shape[0]): + current_hypotheses[idx].y_sequence = logits[idx][: logits_len[idx]] + + hypotheses += current_hypotheses + + del greedy_predictions + del logits + del test_batch + finally: + # set mode back to its original value + self.train(mode=mode) + self.preprocessor.featurizer.dither = dither_value + self.preprocessor.featurizer.pad_to = pad_to_value + return hypotheses + + def change_vocabulary(self, new_vocabulary: List[str]): + """ + Changes vocabulary used during CTC decoding process. Use this method when fine-tuning on from pre-trained model. + This method changes only decoder and leaves encoder and pre-processing modules unchanged. For example, you would + use it if you want to use pretrained encoder when fine-tuning on a data in another language, or when you'd need + model to learn capitalization, punctuation and/or special characters. + + If new_vocabulary == self.decoder.vocabulary then nothing will be changed. + + Args: + + new_vocabulary: list with new vocabulary. Must contain at least 2 elements. Typically, \ + this is target alphabet. + + Returns: None + + """ + assert not self.use_bpe + if self.decoder.vocabulary == new_vocabulary: + logging.warning(f"Old {self.decoder.vocabulary} and new {new_vocabulary} match. Not changing anything.") + else: + if new_vocabulary is None or len(new_vocabulary) == 0: + raise ValueError(f'New vocabulary must be non-empty list of chars. But I got: {new_vocabulary}') + decoder_config = self.decoder.to_config_dict() + new_decoder_config = copy.deepcopy(decoder_config) + new_decoder_config['vocabulary'] = new_vocabulary + new_decoder_config['num_classes'] = len(new_vocabulary) + + del self.decoder + self.decoder = self.from_config_dict(new_decoder_config) + del self.loss + self.loss = CTCLoss( + num_classes=self.decoder.num_classes_with_blank - 1, + zero_infinity=True, + reduction=self._cfg.get("ctc_reduction", "mean_batch"), + ) + if self.label_type in ['phone', 'unit']: + WER_class = WER_phone + else: + WER_class = WER + self._wer = WER_class( + vocabulary=self.decoder.vocabulary, + batch_dim_index=0, + use_cer=self._cfg.get('use_cer', False), + ctc_decode=True, + dist_sync_on_step=True, + log_prediction=self._cfg.get("log_prediction", False), + ) + + # Update config + OmegaConf.set_struct(self._cfg.decoder, False) + self._cfg.decoder = new_decoder_config + OmegaConf.set_struct(self._cfg.decoder, True) + + logging.info(f"Changed decoder to output to {self.decoder.vocabulary} vocabulary.") + + def _setup_dataloader_from_config(self, config: Optional[Dict], noise_perturb_config): + if noise_perturb_config is not None: + noise_perturb = RandomNoisePerturbation(**noise_perturb_config) + augmentor = AudioAugmentor(perturbations=[(1.0, noise_perturb)]) + else: + augmentor = None + + shuffle = config['shuffle'] + + if 'manifest_filepath' in config and config['manifest_filepath'] is None: + logging.warning(f"Could not load dataset as `manifest_filepath` was None. Provided config : {config}") + return None + + if self.add_end_space: + config['parser_add_end_space'] = self.add_end_space + + if self.use_bpe: + dataset = audio_to_text_dataset.get_bpe_dataset(config=config, tokenizer=self.tokenizer, + augmentor=augmentor) + else: + if self.label_type == 'char': + dataset = audio_to_text_dataset.get_char_dataset(config=config, augmentor=augmentor) + elif self.label_type in ['phone', 'unit']: + dataset = audio_to_text_dataset.get_phone_dataset(config=config, augmentor=augmentor) + + return torch.utils.data.DataLoader( + dataset=dataset, + batch_size=config['batch_size'], + collate_fn=dataset.collate_fn, + drop_last=config.get('drop_last', False), + shuffle=shuffle, + num_workers=config.get('num_workers', 0), + pin_memory=config.get('pin_memory', False), + ) + + def setup_training_data(self, train_data_config: Optional[Union[DictConfig, Dict]]): + """ + Sets up the training data loader via a Dict-like object. + + Args: + train_data_config: A config that contains the information regarding construction + of an ASR Training dataset. + + Supported Datasets: + - :class:`~nemo.collections.asr.data.audio_to_text.AudioToCharDataset` + - :class:`~nemo.collections.asr.data.audio_to_text.AudioToBPEDataset` + - :class:`~nemo.collections.asr.data.audio_to_text.TarredAudioToCharDataset` + - :class:`~nemo.collections.asr.data.audio_to_text.TarredAudioToBPEDataset` + - :class:`~nemo.collections.asr.data.audio_to_text_dali.AudioToCharDALIDataset` + """ + if 'shuffle' not in train_data_config: + train_data_config['shuffle'] = True + + # preserve config + self._update_dataset_config(dataset_name='train', config=train_data_config) + + self._train_dl = self._setup_dataloader_from_config(config=train_data_config, + noise_perturb_config=self._cfg['noise_perturb']) + + # Need to set this because if using an IterableDataset, the length of the dataloader is the total number + # of samples rather than the number of batches, and this messes up the tqdm progress bar. + # So we set the number of steps manually (to the correct number) to fix this. + if 'is_tarred' in train_data_config and train_data_config['is_tarred']: + # We also need to check if limit_train_batches is already set. + # If it's an int, we assume that the user has set it to something sane, i.e. <= # training batches, + # and don't change it. Otherwise, adjust batches accordingly if it's a float (including 1.0). + if isinstance(self._trainer.limit_train_batches, float): + self._trainer.limit_train_batches = int( + self._trainer.limit_train_batches + * ceil((len(self._train_dl.dataset) / self.world_size) / train_data_config['batch_size']) + ) + + def setup_validation_data(self, val_data_config: Optional[Union[DictConfig, Dict]]): + """ + Sets up the validation data loader via a Dict-like object. + + Args: + val_data_config: A config that contains the information regarding construction + of an ASR Training dataset. + + Supported Datasets: + - :class:`~nemo.collections.asr.data.audio_to_text.AudioToCharDataset` + - :class:`~nemo.collections.asr.data.audio_to_text.AudioToBPEDataset` + - :class:`~nemo.collections.asr.data.audio_to_text.TarredAudioToCharDataset` + - :class:`~nemo.collections.asr.data.audio_to_text.TarredAudioToBPEDataset` + - :class:`~nemo.collections.asr.data.audio_to_text_dali.AudioToCharDALIDataset` + """ + if 'shuffle' not in val_data_config: + val_data_config['shuffle'] = False + + # preserve config + self._update_dataset_config(dataset_name='validation', config=val_data_config) + + self._validation_dl = self._setup_dataloader_from_config(config=val_data_config, noise_perturb_config=None) + + def setup_test_data(self, test_data_config: Optional[Union[DictConfig, Dict]]): + """ + Sets up the test data loader via a Dict-like object. + + Args: + test_data_config: A config that contains the information regarding construction + of an ASR Training dataset. + + Supported Datasets: + - :class:`~nemo.collections.asr.data.audio_to_text.AudioToCharDataset` + - :class:`~nemo.collections.asr.data.audio_to_text.AudioToBPEDataset` + - :class:`~nemo.collections.asr.data.audio_to_text.TarredAudioToCharDataset` + - :class:`~nemo.collections.asr.data.audio_to_text.TarredAudioToBPEDataset` + - :class:`~nemo.collections.asr.data.audio_to_text_dali.AudioToCharDALIDataset` + """ + if 'shuffle' not in test_data_config: + test_data_config['shuffle'] = False + + # preserve config + self._update_dataset_config(dataset_name='test', config=test_data_config) + + self._test_dl = self._setup_dataloader_from_config(config=test_data_config, noise_perturb_config=None) + + def optim_param_groups(self): + return [{'params': self.encoder.parameters(), 'weight_decay': 0.0}, + {'params': self.decoder.parameters()}] + + # def setup_optimization(self, optim_config: Optional[Union[DictConfig, Dict]] = None): + # """ + # Prepares an optimizer from a string name and its optional config parameters. + # + # Args: + # optim_config: A dictionary containing the following keys: + # + # * "lr": mandatory key for learning rate. Will raise ValueError if not provided. + # * "optimizer": string name pointing to one of the available optimizers in the registry. \ + # If not provided, defaults to "adam". + # * "opt_args": Optional list of strings, in the format "arg_name=arg_value". \ + # The list of "arg_value" will be parsed and a dictionary of optimizer kwargs \ + # will be built and supplied to instantiate the optimizer. + # """ + # ## modifed by me!! + # # If config was not explicitly passed to us + # if optim_config is None: + # # See if internal config has `optim` namespace + # if self._cfg is not None and hasattr(self._cfg, 'optim'): + # optim_config = self._cfg.optim + # + # # If config is still None, or internal config has no Optim, return without instantiation + # if optim_config is None: + # logging.info('No optimizer config provided, therefore no optimizer was created') + # return + # + # else: + # # Preserve the configuration + # if not isinstance(optim_config, DictConfig): + # optim_config = OmegaConf.create(optim_config) + # + # # See if internal config has `optim` namespace before preservation + # if self._cfg is not None and hasattr(self._cfg, 'optim'): + # if self._cfg.optim is None: + # self._cfg.optim = copy.deepcopy(optim_config) + # else: + # with open_dict(self._cfg.optim): + # self._cfg.optim = copy.deepcopy(optim_config) + # + # # Setup optimizer and scheduler + # if optim_config is not None and isinstance(optim_config, DictConfig): + # optim_config = OmegaConf.to_container(optim_config, resolve=True) + # + # if 'sched' in optim_config and optim_config[ + # 'sched'] is not None and self._trainer is not None: ## this line modified! + # if not isinstance(self._trainer.accumulate_grad_batches, int): + # raise ValueError("We do not currently support gradient acculumation that is not an integer.") + # if self._trainer.max_steps is None: + # # Store information needed to calculate max_steps + # optim_config['sched']['t_max_epochs'] = self._trainer.max_epochs + # optim_config['sched']['t_accumulate_grad_batches'] = self._trainer.accumulate_grad_batches + # optim_config['sched']['t_limit_train_batches'] = self._trainer.limit_train_batches + # if self._trainer.distributed_backend is None: + # optim_config['sched']['t_num_workers'] = self._trainer.num_gpus or 1 + # elif self._trainer.distributed_backend == "ddp_cpu": + # optim_config['sched']['t_num_workers'] = self._trainer.num_processes * self._trainer.num_nodes + # elif self._trainer.distributed_backend == "ddp": + # optim_config['sched']['t_num_workers'] = self._trainer.num_gpus * self._trainer.num_nodes + # else: + # logging.warning( + # f"The lightning trainer received accelerator: {self._trainer.distributed_backend}. We " + # "recommend to use 'ddp' instead." + # ) + # optim_config['sched']['t_num_workers'] = self._trainer.num_gpus * self._trainer.num_nodes + # else: + # optim_config['sched']['max_steps'] = self._trainer.max_steps + # + # # Force into DictConfig from nested structure + # optim_config = OmegaConf.create(optim_config) + # # Get back nested dict so we its mutable + # optim_config = OmegaConf.to_container(optim_config, resolve=True) + # + # # Extract scheduler config if inside optimizer config + # if 'sched' in optim_config: + # scheduler_config = optim_config.pop('sched') + # else: + # scheduler_config = None + # + # # Check if caller provided optimizer name, default to Adam otherwise + # optimizer_cls = optim_config.get('_target_', None) + # + # if optimizer_cls is None: + # # Try to get optimizer name for dynamic resolution, defaulting to Adam + # optimizer_name = optim_config.get('name', 'adam') + # else: + # if inspect.isclass(optimizer_cls): + # optimizer_name = optimizer_cls.__name__.lower() + # else: + # # resolve the class name (lowercase) from the class path if not provided + # optimizer_name = optimizer_cls.split(".")[-1].lower() + # + # # We are guarenteed to have lr since it is required by the argparser + # # But maybe user forgot to pass it to this function + # lr = optim_config.get('lr', None) + # + # # Check if caller has optimizer kwargs, default to empty dictionary + # if 'args' in optim_config: + # optimizer_args = optim_config.pop('args') + # optimizer_args = optim.parse_optimizer_args(optimizer_name, optimizer_args) + # else: + # optimizer_args = copy.deepcopy(optim_config) + # + # # Remove extra parameters from optimizer_args nest + # # Assume all other parameters are to be passed into optimizer constructor + # optimizer_args.pop('name', None) + # optimizer_args.pop('cls', None) + # optimizer_args.pop('lr', None) + # + # # Adaptive schedulers don't need `lr` + # if lr is not None: + # optimizer_args['lr'] = lr + # + # # Actually instantiate the optimizer + # if optimizer_cls is not None: + # if inspect.isclass(optimizer_cls): + # optimizer = optimizer_cls(self.optim_param_groups(), **optimizer_args) + # logging.info("Optimizer config = %s", str(optimizer)) + # + # self._optimizer = optimizer + # + # else: + # # Attempt class path resolution + # try: + # optimizer_cls = OmegaConf.create({'_target_': optimizer_cls}) + # if lr is not None: + # optimizer_config = {'lr': lr} + # else: + # optimizer_config = {} + # optimizer_config.update(optimizer_args) + # + # optimizer_instance = hydra.utils.instantiate( + # optimizer_cls, self.optim_param_groups(), **optimizer_config + # ) # type: DictConfig + # + # logging.info("Optimizer config = %s", str(optimizer_instance)) + # + # self._optimizer = optimizer_instance + # + # except Exception as e: + # logging.error( + # "Could not instantiate class path - {} with kwargs {}".format( + # optimizer_cls, str(optimizer_config) + # ) + # ) + # raise e + # + # else: + # optimizer = optim.get_optimizer(optimizer_name) + # optimizer = optimizer(self.optim_param_groups(), **optimizer_args) + # + # logging.info("Optimizer config = %s", str(optimizer)) + # + # self._optimizer = optimizer + # + # # Try to instantiate scheduler for optimizer + # self._scheduler = prepare_lr_scheduler( + # optimizer=self._optimizer, scheduler_config=scheduler_config, train_dataloader=self._train_dl + # ) + # + # # Return the optimizer with/without scheduler + # # This return allows multiple optimizers or schedulers to be created + # return self._optimizer, self._scheduler + + def forward(self, input_signal, input_signal_length, global_step): + """ + Args: + input_signal: Tensor that represents a batch of raw audio signals, + of shape [B, T]. T here represents timesteps, with 1 second of audio represented as + `self.sample_rate` number of floating point values. + input_signal_length: Vector of length B, that contains the individual lengths of the audio + sequences. + """ + ft = False if global_step is None else self.freeze_finetune_updates <= global_step + with torch.no_grad() if not ft else contextlib.suppress(): + encoded, encoded_len = self.encoder(input_signal, input_signal_length, None, None, + mask=self.training, features_only=True) + + # [B, T, D] => [B, D, T] + encoded = encoded.transpose(1, 2) + + # Ensure that shape mismatch does not occur due to padding + # Due to padding and subsequent downsampling, it may be possible that + # max sequence length computed does not match the actual max sequence length + max_output_len = encoded_len.max() + if encoded.shape[2] != max_output_len: + encoded = encoded.narrow(dim=2, start=0, length=max_output_len).contiguous() + + logits, encoded_len = self.decoder(encoder_output=encoded, lens=encoded_len, log_prob=False) + log_probs = torch.nn.functional.log_softmax(logits, dim=-1) + + with torch.no_grad(): + greedy_predictions = log_probs.argmax(dim=-1, keepdim=False) + + return log_probs, encoded_len, greedy_predictions, logits + + def training_step(self, batch, batch_nb): + signal, signal_len, transcript, transcript_len = batch + log_probs, encoded_len, predictions, _ = self(input_signal=signal, input_signal_length=signal_len, + global_step=self.trainer.global_step) + loss = self.loss( + log_probs=log_probs, targets=transcript, input_lengths=encoded_len, target_lengths=transcript_len + ) + tensorboard_logs = {'train_loss': loss, 'learning_rate': self._optimizer.param_groups[0]['lr']} + + return {'loss': loss, 'log': tensorboard_logs} + + def validation_step(self, batch, batch_idx, dataloader_idx=0, decode_results=None): + signal, signal_len, transcript, transcript_len = batch + with torch.no_grad(): + log_probs, encoded_len, predictions, logits = self(input_signal=signal, input_signal_length=signal_len, + global_step=None) + loss = self.loss( + log_probs=log_probs, targets=transcript, input_lengths=encoded_len, target_lengths=transcript_len + ) + self._wer.update( + predictions=predictions, targets=transcript, target_lengths=transcript_len, predictions_lengths=encoded_len, + log_prediction=batch_idx < 3, decode_results=decode_results) + wer, wer_num, wer_denom = self._wer.compute() + return { + 'val_loss': loss, + 'val_wer_num': wer_num, + 'val_wer_denom': wer_denom, + 'val_wer': wer, + 'val_logprob': log_probs.cpu().numpy(), + 'val_logprob_len': encoded_len.cpu().numpy(), + 'val_logits': logits.cpu().numpy(), + } + + def test_step(self, batch, batch_idx, dataloader_idx=0): + decode_results = {} + logs = self.validation_step(batch, batch_idx, dataloader_idx=dataloader_idx, decode_results=decode_results) + test_logs = { + 'test_loss': logs['val_loss'], + 'test_wer_num': logs['val_wer_num'], + 'test_wer_denom': logs['val_wer_denom'], + 'test_wer': logs['val_wer'], + 'test_references': decode_results['references'], + 'test_hypotheses': decode_results['hypotheses'], + 'test_logprob': logs['val_logprob'], + 'test_logprob_len': logs['val_logprob_len'], + 'test_logits': logs['val_logits'], + } + return test_logs + + def test_dataloader(self): + if self._test_dl is not None: + return self._test_dl + + def _setup_transcribe_dataloader(self, config: Dict) -> 'torch.utils.data.DataLoader': + """ + Setup function for a temporary data loader which wraps the provided audio file. + + Args: + config: A python dictionary which contains the following keys: + paths2audio_files: (a list) of paths to audio files. The files should be relatively short fragments. \ + Recommended length per file is between 5 and 25 seconds. + batch_size: (int) batch size to use during inference. \ + Bigger will result in better throughput performance but would use more memory. + temp_dir: (str) A temporary directory where the audio manifest is temporarily + stored. + + Returns: + A pytorch DataLoader for the given audio file(s). + """ + assert not self.use_bpe + dl_config = { + 'manifest_filepath': os.path.join(config['temp_dir'], 'manifest.json'), + 'sample_rate': self.preprocessor._sample_rate, + 'labels': self.decoder.vocabulary, + 'batch_size': min(config['batch_size'], len(config['paths2audio_files'])), + 'trim_silence': True, + 'shuffle': False, + } + + temporary_datalayer = self._setup_dataloader_from_config(config=DictConfig(dl_config), + noise_perturb_config=None) + return temporary_datalayer + + @classmethod + def init_encoder_from_pretrain_model( + cls, + encoder, + encoder_param_prefix, + checkpoint_path, + *, + map_location=None, + strict: bool = True): + try: + cls._set_model_restore_state(is_being_restored=True) + + from pytorch_lightning.utilities.cloud_io import load as pl_load + if map_location is not None: + checkpoint = pl_load(checkpoint_path, map_location=map_location) + else: + checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage) + + # for past checkpoint need to add the new key + assert cls.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint + pretrain_cfg = checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY] + + # give model a chance to load something + # model.on_load_checkpoint(checkpoint) + + # load the state_dict on the model automatically + if encoder_param_prefix is not None: + encoder_state = {k[len(encoder_param_prefix):]: v for k, v in checkpoint['state_dict'].items() if + k.startswith(encoder_param_prefix)} + else: + encoder_state = checkpoint['state_dict'] + encoder.load_state_dict(encoder_state, strict=strict) + finally: + cls._set_model_restore_state(is_being_restored=False) + + def multi_test_epoch_end(self, outputs, dataloader_idx: int = 0): + val_loss_mean = torch.stack([x['test_loss'] for x in outputs]).mean() + wer_num = torch.stack([x['test_wer_num'] for x in outputs]).sum() + wer_denom = torch.stack([x['test_wer_denom'] for x in outputs]).sum() + tensorboard_logs = {'test_loss': val_loss_mean, 'test_wer': wer_num / wer_denom} + references = itertools.chain.from_iterable([x['test_references'] for x in outputs]) + hypotheses = itertools.chain.from_iterable([x['test_hypotheses'] for x in outputs]) + test_logprob = [x['test_logprob'] for x in outputs] + test_logprob_len = [x['test_logprob_len'] for x in outputs] + test_logits = [x['test_logits'] for x in outputs] + return {'test_loss': val_loss_mean, 'log': tensorboard_logs, 'decode_results': (references, hypotheses), + 'test_logprob': test_logprob, 'test_logprob_len': test_logprob_len, 'test_logits': test_logits} diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/spec2vec/vq_ctc_finetune.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/spec2vec/vq_ctc_finetune.py new file mode 100644 index 0000000000000000000000000000000000000000..d962c2af05701a67d8ddcc42ab74e56361742d55 --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/spec2vec/vq_ctc_finetune.py @@ -0,0 +1,801 @@ +"""Modified from .ctc_finetune.py""" +import contextlib +import copy +import itertools +import json +import os +import tempfile +from math import ceil +from typing import Dict, List, Optional, Union + +import torch +from omegaconf import DictConfig, OmegaConf, open_dict, ListConfig +from pytorch_lightning import Trainer +from tqdm.auto import tqdm + +from nemo.collections.asr.data import audio_to_text_dataset +from nemo.collections.asr.losses.ctc import CTCLoss +from nemo.collections.asr.metrics.wer import WER, WER_phone +from nemo.collections.asr.metrics.wer_bpe import WERBPE +from nemo.collections.asr.models.asr_model import ASRModel +from nemo.collections.asr.parts.perturb import process_augmentations, RandomNoisePerturbation, AudioAugmentor +from nemo.utils import logging +from nemo.collections.asr.modules.wav2vec_modules import compute_mask_indices, GumbelVectorQuantizer +from nemo.core.classes.modelPT import inspect, optim, hydra, prepare_lr_scheduler + +import torch.nn as nn +from my_scripts.fsq import FSQ + + +class VQCTCFinetuneModel(ASRModel): + """Todo: should modify""" + + def __init__(self, cfg: DictConfig, trainer: Trainer = None): + # Get global rank and total number of GPU workers for IterableDataset partitioning, if applicable + # Global_rank and local_rank is set by LightningModule in Lightning 1.2.0 + self.global_rank = 0 + self.world_size = 1 + self.local_rank = 0 + if trainer is not None: + self.global_rank = (trainer.node_rank * trainer.num_gpus) + trainer.local_rank + self.world_size = trainer.num_nodes * trainer.num_gpus + self.local_rank = trainer.local_rank + + self.label_type = cfg.label_type + assert self.label_type in ['char', 'phone', 'bpe'] + if self.label_type == 'bpe': + self.use_bpe = True + assert cfg.tokenizer is not None + from nemo.collections.asr.parts.mixins import ASRBPEMixin + self.bpe = ASRBPEMixin() + self.bpe._setup_tokenizer(cfg.tokenizer, register_artifact=False) + self.tokenizer = self.bpe.tokenizer + + # Initialize a dummy vocabulary + vocabulary = self.tokenizer.tokenizer.get_vocab() + + # Set the new vocabulary + assert len(cfg.decoder.vocabulary) == 0 + with open_dict(cfg): + cfg.decoder.vocabulary = ListConfig(list(vocabulary.values())) + else: + self.use_bpe = False + assert cfg.tokenizer is None + self.add_end_space = cfg.add_end_space + self.lang = cfg.lang + + super().__init__(cfg=cfg, trainer=trainer) + + if self._cfg.encoder_type == 'spec2vec': + from nemo.collections.asr.models.spec2vec.spec2vec_model import Spec2VecEncoder + self.encoder = Spec2VecEncoder(self._cfg.encoder) + encoder_param_prefix = 'spec2vec_encoder.' + elif self._cfg.encoder_type == 'st': + from nemo.collections.asr.models.st2vec.st2vec_model import ST2VecEncoder + self.encoder = ST2VecEncoder(self._cfg.encoder) + encoder_param_prefix = 'st2vec_encoder.' + elif self._cfg.encoder_type == 'feat_st': + from nemo.collections.asr.models.st2vec.st2vec_model import FeatST2VecEncoder + self.encoder = FeatST2VecEncoder(self._cfg.encoder) + encoder_param_prefix = 'st2vec_encoder.' + else: + assert self._cfg.encoder_type == 'wav2vec' + from nemo.collections.asr.modules.wav2vec_encoder import Wav2VecEncoderModel + self.encoder = Wav2VecEncoderModel(self._cfg.encoder) + encoder_param_prefix = None + if cfg.pretrain_chkpt_path is not None: + self.init_encoder_from_pretrain_model(self.encoder, encoder_param_prefix, cfg.pretrain_chkpt_path) + if self._cfg.encoder_type == 'st': + self.encoder.remove_pretraining_modules(use_teacher_encoder=self._cfg.use_teacher_encoder) + else: + self.encoder.remove_pretraining_modules() + + # # quantizer_input_dim = self.encoder.feature_encoder.output_dim + # self.pre_logits_bottleneck_layer = nn.Linear(self.encoder.feature_encoder.output_dim, cfg.quantizer.pre_logits_dim) + # + # self.quantizer = GumbelVectorQuantizer( + # dim=cfg.quantizer.pre_logits_dim, # quantizer_input_dim, + # num_vars=cfg.quantizer.latent_vars, + # temp=cfg.quantizer.latent_temp, + # groups=cfg.quantizer.latent_groups, + # combine_groups=False, + # vq_dim=cfg.quantizer.latent_dim, + # time_first=True, + # ) + + levels = cfg.quantizer.levels + self.pre_quant = nn.Linear(self.encoder.feature_encoder.output_dim, len(levels)) + self.quantizer = FSQ(levels=levels, l2_norm=cfg.quantizer.l2_norm, batch_norm=cfg.quantizer.batch_norm) + + self.decoder = self.from_config_dict(self._cfg.decoder) + + self.freeze_finetune_updates = self._cfg.freeze_finetune_updates + + self.ctc_loss = CTCLoss( + blank_id=self.decoder.blank_idx, + zero_infinity=True, + reduction=self._cfg.get("ctc_reduction", "mean_batch"), + ) + + self.quant_ppl_loss_weight = cfg.quant_ppl_loss_weight + + # Setup metric objects + if self.use_bpe: + self._wer = WERBPE( + tokenizer=self.tokenizer, + blank_id=self.decoder.blank_idx, + batch_dim_index=0, + use_cer=self._cfg.get('use_cer', False), + ctc_decode=True, + dist_sync_on_step=True, + log_prediction=self._cfg.get("log_prediction", False), + lang=self.lang, + ) + else: + if self.label_type == 'phone': + WER_class = WER_phone + else: + WER_class = WER + self._wer = WER_class( + vocabulary=self.decoder.vocabulary, + blank_id=self.decoder.blank_idx, + batch_dim_index=0, + use_cer=self._cfg.get('use_cer', False), + ctc_decode=True, + dist_sync_on_step=True, + log_prediction=self._cfg.get("log_prediction", False), + strip_end_space=self.add_end_space, + ) + + @torch.no_grad() + def transcribe( + self, paths2audio_files: List[str], batch_size: int = 4, logprobs=False, return_hypotheses: bool = False + ) -> List[str]: + """ + Uses greedy decoding to transcribe audio files. Use this method for debugging and prototyping. + + Args: + paths2audio_files: (a list) of paths to audio files. \ + Recommended length per file is between 5 and 25 seconds. \ + But it is possible to pass a few hours long file if enough GPU memory is available. + batch_size: (int) batch size to use during inference. + Bigger will result in better throughput performance but would use more memory. + logprobs: (bool) pass True to get log probabilities instead of transcripts. + return_hypotheses: (bool) Either return hypotheses or text + With hypotheses can do some postprocessing like getting timestamp or rescoring + + Returns: + A list of transcriptions (or raw log probabilities if logprobs is True) in the same order as paths2audio_files + """ + if paths2audio_files is None or len(paths2audio_files) == 0: + return {} + + if return_hypotheses and logprobs: + raise ValueError( + "Either `return_hypotheses` or `logprobs` can be True at any given time." + "Returned hypotheses will contain the logprobs." + ) + + # We will store transcriptions here + hypotheses = [] + # Model's mode and device + mode = self.training + device = next(self.parameters()).device + # dither_value = self.preprocessor.featurizer.dither + # pad_to_value = self.preprocessor.featurizer.pad_to + dither_value = self.encoder.wav2spec.featurizer.dither + pad_to_value = self.encoder.wav2spec.featurizer.pad_to + + try: + # self.preprocessor.featurizer.dither = 0.0 + # self.preprocessor.featurizer.pad_to = 0 + self.encoder.wav2spec.featurizer.dither = 0.0 + self.encoder.wav2spec.featurizer.pad_to = 0 + # Switch model to evaluation mode + self.eval() + # Freeze the encoder and decoder modules + # Work in tmp directory - will store manifest file there + with tempfile.TemporaryDirectory() as tmpdir: + with open(os.path.join(tmpdir, 'manifest.json'), 'w') as fp: + for audio_file in paths2audio_files: + entry = {'audio_filepath': audio_file, 'duration': 100000, 'text': 'nothing'} + fp.write(json.dumps(entry) + '\n') + + config = {'paths2audio_files': paths2audio_files, 'batch_size': batch_size, 'temp_dir': tmpdir} + + temporary_datalayer = self._setup_transcribe_dataloader(config) + for test_batch in tqdm(temporary_datalayer, desc="Transcribing"): + # logits, logits_len, greedy_predictions = self( + # input_signal=test_batch[0].to(device), input_signal_length=test_batch[1].to(device) + # ) + + # log_probs, logits_len, greedy_predictions, logits, prob_ppl_loss, cur_temp, prob_ppl = \ + # self(input_signal=test_batch[0].to(device), input_signal_length=test_batch[1].to(device), + # global_step=self.trainer.global_step) + + log_probs, logits_len, greedy_predictions, logits = \ + self(input_signal=test_batch[0].to(device), input_signal_length=test_batch[1].to(device), + global_step=self.trainer.global_step) + + if logprobs: + # dump log probs per file + for idx in range(logits.shape[0]): + hypotheses.append(logits[idx][: logits_len[idx]]) + else: + current_hypotheses = self._wer.ctc_decoder_predictions_tensor( + greedy_predictions, predictions_len=logits_len, return_hypotheses=return_hypotheses, + ) + + if return_hypotheses: + # dump log probs per file + for idx in range(logits.shape[0]): + current_hypotheses[idx].y_sequence = logits[idx][: logits_len[idx]] + + hypotheses += current_hypotheses + + del greedy_predictions + del logits + del test_batch + finally: + # set mode back to its original value + self.train(mode=mode) + # self.preprocessor.featurizer.dither = dither_value + # self.preprocessor.featurizer.pad_to = pad_to_value + self.encoder.wav2spec.featurizer.dither = dither_value + self.encoder.wav2spec.featurizer.pad_to = pad_to_value + return hypotheses + + def change_vocabulary(self, new_vocabulary: List[str]): + """ + Changes vocabulary used during CTC decoding process. Use this method when fine-tuning on from pre-trained model. + This method changes only decoder and leaves encoder and pre-processing modules unchanged. For example, you would + use it if you want to use pretrained encoder when fine-tuning on a data in another language, or when you'd need + model to learn capitalization, punctuation and/or special characters. + + If new_vocabulary == self.decoder.vocabulary then nothing will be changed. + + Args: + + new_vocabulary: list with new vocabulary. Must contain at least 2 elements. Typically, \ + this is target alphabet. + + Returns: None + + """ + assert not self.use_bpe + if self.decoder.vocabulary == new_vocabulary: + logging.warning(f"Old {self.decoder.vocabulary} and new {new_vocabulary} match. Not changing anything.") + else: + if new_vocabulary is None or len(new_vocabulary) == 0: + raise ValueError(f'New vocabulary must be non-empty list of chars. But I got: {new_vocabulary}') + decoder_config = self.decoder.to_config_dict() + new_decoder_config = copy.deepcopy(decoder_config) + new_decoder_config['vocabulary'] = new_vocabulary + new_decoder_config['num_classes'] = len(new_vocabulary) + + del self.decoder + self.decoder = self.from_config_dict(new_decoder_config) + del self.ctc_loss + self.ctc_loss = CTCLoss( + num_classes=self.decoder.num_classes_with_blank - 1, + zero_infinity=True, + reduction=self._cfg.get("ctc_reduction", "mean_batch"), + ) + if self.label_type == 'phone': + WER_class = WER_phone + else: + WER_class = WER + self._wer = WER_class( + vocabulary=self.decoder.vocabulary, + batch_dim_index=0, + use_cer=self._cfg.get('use_cer', False), + ctc_decode=True, + dist_sync_on_step=True, + log_prediction=self._cfg.get("log_prediction", False), + ) + + # Update config + OmegaConf.set_struct(self._cfg.decoder, False) + self._cfg.decoder = new_decoder_config + OmegaConf.set_struct(self._cfg.decoder, True) + + logging.info(f"Changed decoder to output to {self.decoder.vocabulary} vocabulary.") + + def _setup_dataloader_from_config(self, config: Optional[Dict], noise_perturb_config): + if noise_perturb_config is not None: + noise_perturb = RandomNoisePerturbation(**noise_perturb_config) + augmentor = AudioAugmentor(perturbations=[(1.0, noise_perturb)]) + else: + augmentor = None + + shuffle = config['shuffle'] + + if 'manifest_filepath' in config and config['manifest_filepath'] is None: + logging.warning(f"Could not load dataset as `manifest_filepath` was None. Provided config : {config}") + return None + + if self.add_end_space: + config['parser_add_end_space'] = self.add_end_space + + if self.use_bpe: + dataset = audio_to_text_dataset.get_bpe_dataset(config=config, tokenizer=self.tokenizer, + augmentor=augmentor) + else: + if self.label_type == 'char': + dataset = audio_to_text_dataset.get_char_dataset(config=config, augmentor=augmentor) + elif self.label_type == 'phone': + dataset = audio_to_text_dataset.get_phone_dataset(config=config, augmentor=augmentor) + + return torch.utils.data.DataLoader( + dataset=dataset, + batch_size=config['batch_size'], + collate_fn=dataset.collate_fn, + drop_last=config.get('drop_last', False), + shuffle=shuffle, + num_workers=config.get('num_workers', 0), + pin_memory=config.get('pin_memory', False), + ) + + def setup_training_data(self, train_data_config: Optional[Union[DictConfig, Dict]]): + """ + Sets up the training data loader via a Dict-like object. + + Args: + train_data_config: A config that contains the information regarding construction + of an ASR Training dataset. + + Supported Datasets: + - :class:`~nemo.collections.asr.data.audio_to_text.AudioToCharDataset` + - :class:`~nemo.collections.asr.data.audio_to_text.AudioToBPEDataset` + - :class:`~nemo.collections.asr.data.audio_to_text.TarredAudioToCharDataset` + - :class:`~nemo.collections.asr.data.audio_to_text.TarredAudioToBPEDataset` + - :class:`~nemo.collections.asr.data.audio_to_text_dali.AudioToCharDALIDataset` + """ + if 'shuffle' not in train_data_config: + train_data_config['shuffle'] = True + + # preserve config + self._update_dataset_config(dataset_name='train', config=train_data_config) + + self._train_dl = self._setup_dataloader_from_config(config=train_data_config, + noise_perturb_config=self._cfg['noise_perturb']) + + # Need to set this because if using an IterableDataset, the length of the dataloader is the total number + # of samples rather than the number of batches, and this messes up the tqdm progress bar. + # So we set the number of steps manually (to the correct number) to fix this. + if 'is_tarred' in train_data_config and train_data_config['is_tarred']: + # We also need to check if limit_train_batches is already set. + # If it's an int, we assume that the user has set it to something sane, i.e. <= # training batches, + # and don't change it. Otherwise, adjust batches accordingly if it's a float (including 1.0). + if isinstance(self._trainer.limit_train_batches, float): + self._trainer.limit_train_batches = int( + self._trainer.limit_train_batches + * ceil((len(self._train_dl.dataset) / self.world_size) / train_data_config['batch_size']) + ) + + def setup_validation_data(self, val_data_config: Optional[Union[DictConfig, Dict]]): + """ + Sets up the validation data loader via a Dict-like object. + + Args: + val_data_config: A config that contains the information regarding construction + of an ASR Training dataset. + + Supported Datasets: + - :class:`~nemo.collections.asr.data.audio_to_text.AudioToCharDataset` + - :class:`~nemo.collections.asr.data.audio_to_text.AudioToBPEDataset` + - :class:`~nemo.collections.asr.data.audio_to_text.TarredAudioToCharDataset` + - :class:`~nemo.collections.asr.data.audio_to_text.TarredAudioToBPEDataset` + - :class:`~nemo.collections.asr.data.audio_to_text_dali.AudioToCharDALIDataset` + """ + if 'shuffle' not in val_data_config: + val_data_config['shuffle'] = False + + # preserve config + self._update_dataset_config(dataset_name='validation', config=val_data_config) + + self._validation_dl = self._setup_dataloader_from_config(config=val_data_config, noise_perturb_config=None) + + def setup_test_data(self, test_data_config: Optional[Union[DictConfig, Dict]]): + """ + Sets up the test data loader via a Dict-like object. + + Args: + test_data_config: A config that contains the information regarding construction + of an ASR Training dataset. + + Supported Datasets: + - :class:`~nemo.collections.asr.data.audio_to_text.AudioToCharDataset` + - :class:`~nemo.collections.asr.data.audio_to_text.AudioToBPEDataset` + - :class:`~nemo.collections.asr.data.audio_to_text.TarredAudioToCharDataset` + - :class:`~nemo.collections.asr.data.audio_to_text.TarredAudioToBPEDataset` + - :class:`~nemo.collections.asr.data.audio_to_text_dali.AudioToCharDALIDataset` + """ + if 'shuffle' not in test_data_config: + test_data_config['shuffle'] = False + + # preserve config + self._update_dataset_config(dataset_name='test', config=test_data_config) + + self._test_dl = self._setup_dataloader_from_config(config=test_data_config, noise_perturb_config=None) + + # def optim_param_groups(self): + # return [{'params': self.encoder.parameters(), 'weight_decay': 0.0}, + # {'params': self.decoder.parameters()}] + + def setup_optimization(self, optim_config: Optional[Union[DictConfig, Dict]] = None): + """ + Prepares an optimizer from a string name and its optional config parameters. + + Args: + optim_config: A dictionary containing the following keys: + + * "lr": mandatory key for learning rate. Will raise ValueError if not provided. + * "optimizer": string name pointing to one of the available optimizers in the registry. \ + If not provided, defaults to "adam". + * "opt_args": Optional list of strings, in the format "arg_name=arg_value". \ + The list of "arg_value" will be parsed and a dictionary of optimizer kwargs \ + will be built and supplied to instantiate the optimizer. + """ + ## modifed by me!! + # If config was not explicitly passed to us + if optim_config is None: + # See if internal config has `optim` namespace + if self._cfg is not None and hasattr(self._cfg, 'optim'): + optim_config = self._cfg.optim + + # If config is still None, or internal config has no Optim, return without instantiation + if optim_config is None: + logging.info('No optimizer config provided, therefore no optimizer was created') + return + + else: + # Preserve the configuration + if not isinstance(optim_config, DictConfig): + optim_config = OmegaConf.create(optim_config) + + # See if internal config has `optim` namespace before preservation + if self._cfg is not None and hasattr(self._cfg, 'optim'): + if self._cfg.optim is None: + self._cfg.optim = copy.deepcopy(optim_config) + else: + with open_dict(self._cfg.optim): + self._cfg.optim = copy.deepcopy(optim_config) + + # Setup optimizer and scheduler + if optim_config is not None and isinstance(optim_config, DictConfig): + optim_config = OmegaConf.to_container(optim_config, resolve=True) + + if 'sched' in optim_config and optim_config[ + 'sched'] is not None and self._trainer is not None: ## this line modified! + if not isinstance(self._trainer.accumulate_grad_batches, int): + raise ValueError("We do not currently support gradient acculumation that is not an integer.") + if self._trainer.max_steps is None: + # Store information needed to calculate max_steps + optim_config['sched']['t_max_epochs'] = self._trainer.max_epochs + optim_config['sched']['t_accumulate_grad_batches'] = self._trainer.accumulate_grad_batches + optim_config['sched']['t_limit_train_batches'] = self._trainer.limit_train_batches + if self._trainer.distributed_backend is None: + optim_config['sched']['t_num_workers'] = self._trainer.num_gpus or 1 + elif self._trainer.distributed_backend == "ddp_cpu": + optim_config['sched']['t_num_workers'] = self._trainer.num_processes * self._trainer.num_nodes + elif self._trainer.distributed_backend == "ddp": + optim_config['sched']['t_num_workers'] = self._trainer.num_gpus * self._trainer.num_nodes + else: + logging.warning( + f"The lightning trainer received accelerator: {self._trainer.distributed_backend}. We " + "recommend to use 'ddp' instead." + ) + optim_config['sched']['t_num_workers'] = self._trainer.num_gpus * self._trainer.num_nodes + else: + optim_config['sched']['max_steps'] = self._trainer.max_steps + + # Force into DictConfig from nested structure + optim_config = OmegaConf.create(optim_config) + # Get back nested dict so we its mutable + optim_config = OmegaConf.to_container(optim_config, resolve=True) + + # Extract scheduler config if inside optimizer config + if 'sched' in optim_config: + scheduler_config = optim_config.pop('sched') + else: + scheduler_config = None + + # Check if caller provided optimizer name, default to Adam otherwise + optimizer_cls = optim_config.get('_target_', None) + + if optimizer_cls is None: + # Try to get optimizer name for dynamic resolution, defaulting to Adam + optimizer_name = optim_config.get('name', 'adam') + else: + if inspect.isclass(optimizer_cls): + optimizer_name = optimizer_cls.__name__.lower() + else: + # resolve the class name (lowercase) from the class path if not provided + optimizer_name = optimizer_cls.split(".")[-1].lower() + + # We are guarenteed to have lr since it is required by the argparser + # But maybe user forgot to pass it to this function + lr = optim_config.get('lr', None) + + # Check if caller has optimizer kwargs, default to empty dictionary + if 'args' in optim_config: + optimizer_args = optim_config.pop('args') + optimizer_args = optim.parse_optimizer_args(optimizer_name, optimizer_args) + else: + optimizer_args = copy.deepcopy(optim_config) + + # Remove extra parameters from optimizer_args nest + # Assume all other parameters are to be passed into optimizer constructor + optimizer_args.pop('name', None) + optimizer_args.pop('cls', None) + optimizer_args.pop('lr', None) + + # Adaptive schedulers don't need `lr` + if lr is not None: + optimizer_args['lr'] = lr + + # Actually instantiate the optimizer + if optimizer_cls is not None: + if inspect.isclass(optimizer_cls): + optimizer = optimizer_cls(self.optim_param_groups(), **optimizer_args) + logging.info("Optimizer config = %s", str(optimizer)) + + self._optimizer = optimizer + + else: + # Attempt class path resolution + try: + optimizer_cls = OmegaConf.create({'_target_': optimizer_cls}) + if lr is not None: + optimizer_config = {'lr': lr} + else: + optimizer_config = {} + optimizer_config.update(optimizer_args) + + optimizer_instance = hydra.utils.instantiate( + optimizer_cls, self.optim_param_groups(), **optimizer_config + ) # type: DictConfig + + logging.info("Optimizer config = %s", str(optimizer_instance)) + + self._optimizer = optimizer_instance + + except Exception as e: + logging.error( + "Could not instantiate class path - {} with kwargs {}".format( + optimizer_cls, str(optimizer_config) + ) + ) + raise e + + else: + optimizer = optim.get_optimizer(optimizer_name) + optimizer = optimizer(self.optim_param_groups(), **optimizer_args) + + logging.info("Optimizer config = %s", str(optimizer)) + + self._optimizer = optimizer + + # Try to instantiate scheduler for optimizer + self._scheduler = prepare_lr_scheduler( + optimizer=self._optimizer, scheduler_config=scheduler_config, train_dataloader=self._train_dl + ) + + # Return the optimizer with/without scheduler + # This return allows multiple optimizers or schedulers to be created + return self._optimizer, self._scheduler + + def forward(self, input_signal, input_signal_length, global_step): + """ + Args: + input_signal: Tensor that represents a batch of raw audio signals, + of shape [B, T]. T here represents timesteps, with 1 second of audio represented as + `self.sample_rate` number of floating point values. + input_signal_length: Vector of length B, that contains the individual lengths of the audio + sequences. + """ + ft = False if global_step is None else self.freeze_finetune_updates <= global_step + with torch.no_grad() if not ft else contextlib.suppress(): + encoded, encoded_len = self.encoder(input_signal, input_signal_length, None, None, + mask=self.training, features_only=True) + + # [B, T, D] => [B, D, T] + encoded = encoded.transpose(1, 2) + + # Ensure that shape mismatch does not occur due to padding + # Due to padding and subsequent downsampling, it may be possible that + # max sequence length computed does not match the actual max sequence length + max_output_len = encoded_len.max() + if encoded.shape[2] != max_output_len: + encoded = encoded.narrow(dim=2, start=0, length=max_output_len).contiguous() + + # [B, D, T] => [B, T, D] + encoded = encoded.transpose(1, 2) + + # self.quantizer.set_num_updates(global_step) + # + # encoded = self.pre_logits_bottleneck_layer(encoded) + # quantized_encoded, prob_ppl_loss, cur_temp, prob_ppl = self.quantizer(encoded) + + encoded = self.pre_quant(encoded) + # quantized_encoded = self.quantizer.quantize(encoded) + quantized_encoded = self.quantizer(encoded) + + # Todo: should exclude the units that are beyond the sequence length from quantization optimization + # [B, T, D] => [B, D, T] + quantized_encoded = quantized_encoded.transpose(1, 2) + + logits, encoded_len = self.decoder(encoder_output=quantized_encoded, lens=encoded_len, log_prob=False) + log_probs = torch.nn.functional.log_softmax(logits, dim=-1) + + with torch.no_grad(): + greedy_predictions = log_probs.argmax(dim=-1, keepdim=False) + + # return log_probs, encoded_len, greedy_predictions, logits, prob_ppl_loss, cur_temp, prob_ppl + return log_probs, encoded_len, greedy_predictions, logits + + def training_step(self, batch, batch_nb): + signal, signal_len, transcript, transcript_len = batch + + # log_probs, encoded_len, predictions, _, prob_ppl_loss, cur_temp, prob_ppl = \ + # self(input_signal=signal, input_signal_length=signal_len, global_step=self.trainer.global_step) + + log_probs, encoded_len, predictions, _ = \ + self(input_signal=signal, input_signal_length=signal_len, global_step=self.trainer.global_step) + + ctc_loss = self.ctc_loss( + log_probs=log_probs, targets=transcript, input_lengths=encoded_len, target_lengths=transcript_len + ) + + # if self.quant_ppl_loss_weight != 0: + # weighted_prob_ppl_loss = self.quant_ppl_loss_weight * prob_ppl_loss + # # Todo: check, this is mean value of batch + # loss = ctc_loss + weighted_prob_ppl_loss + # else: + # loss = ctc_loss + loss = ctc_loss + + # tensorboard_logs = {'train_loss': loss, 'ctc_loss': ctc_loss, 'weighted_prob_ppl_loss': weighted_prob_ppl_loss, + # 'prob_ppl_loss': prob_ppl_loss, 'cur_temp': cur_temp, 'prob_ppl': prob_ppl, + # 'learning_rate': self._optimizer.param_groups[0]['lr']} + tensorboard_logs = {'train_loss': loss, 'ctc_loss': ctc_loss, + 'learning_rate': self._optimizer.param_groups[0]['lr']} + + return {'loss': loss, 'log': tensorboard_logs} + + def validation_step(self, batch, batch_idx, dataloader_idx=0, decode_results=None): + signal, signal_len, transcript, transcript_len = batch + + with torch.no_grad(): + # log_probs, encoded_len, predictions, logits, prob_ppl_loss, cur_temp, prob_ppl = \ + # self(input_signal=signal, input_signal_length=signal_len, global_step=self.trainer.global_step) + + log_probs, encoded_len, predictions, logits, = \ + self(input_signal=signal, input_signal_length=signal_len, global_step=self.trainer.global_step) + + ctc_loss = self.ctc_loss( + log_probs=log_probs, targets=transcript, input_lengths=encoded_len, target_lengths=transcript_len + ) + + # if self.quant_ppl_loss_weight != 0: + # weighted_prob_ppl_loss = self.quant_ppl_loss_weight * prob_ppl_loss + # # Todo: check, this is mean value of batch + # loss = ctc_loss + weighted_prob_ppl_loss + # else: + # loss = ctc_loss + loss = ctc_loss + + self._wer.update( + predictions=predictions, targets=transcript, target_lengths=transcript_len, predictions_lengths=encoded_len, + log_prediction=batch_idx < 3, decode_results=decode_results) + wer, wer_num, wer_denom = self._wer.compute() + + return { + 'val_loss': loss, + 'val_wer_num': wer_num, + 'val_wer_denom': wer_denom, + 'val_wer': wer, + 'val_logprob': log_probs.cpu().numpy(), + 'val_logprob_len': encoded_len.cpu().numpy(), + 'val_logits': logits.cpu().numpy(), + } + + def test_step(self, batch, batch_idx, dataloader_idx=0): + decode_results = {} + logs = self.validation_step(batch, batch_idx, dataloader_idx=dataloader_idx, decode_results=decode_results) + test_logs = { + 'test_loss': logs['val_loss'], + 'test_wer_num': logs['val_wer_num'], + 'test_wer_denom': logs['val_wer_denom'], + 'test_wer': logs['val_wer'], + 'test_references': decode_results['references'], + 'test_hypotheses': decode_results['hypotheses'], + 'test_logprob': logs['val_logprob'], + 'test_logprob_len': logs['val_logprob_len'], + 'test_logits': logs['val_logits'], + } + return test_logs + + def test_dataloader(self): + if self._test_dl is not None: + return self._test_dl + + def _setup_transcribe_dataloader(self, config: Dict) -> 'torch.utils.data.DataLoader': + """ + Setup function for a temporary data loader which wraps the provided audio file. + + Args: + config: A python dictionary which contains the following keys: + paths2audio_files: (a list) of paths to audio files. The files should be relatively short fragments. \ + Recommended length per file is between 5 and 25 seconds. + batch_size: (int) batch size to use during inference. \ + Bigger will result in better throughput performance but would use more memory. + temp_dir: (str) A temporary directory where the audio manifest is temporarily + stored. + + Returns: + A pytorch DataLoader for the given audio file(s). + """ + assert not self.use_bpe + dl_config = { + 'manifest_filepath': os.path.join(config['temp_dir'], 'manifest.json'), + # 'sample_rate': self.preprocessor._sample_rate, + 'sample_rate': self.encoder.wav2spec._sample_rate, + 'labels': self.decoder.vocabulary, + 'batch_size': min(config['batch_size'], len(config['paths2audio_files'])), + 'trim_silence': True, + 'shuffle': False, + } + + temporary_datalayer = self._setup_dataloader_from_config(config=DictConfig(dl_config), + noise_perturb_config=None) + return temporary_datalayer + + @classmethod + def init_encoder_from_pretrain_model( + cls, + encoder, + encoder_param_prefix, + checkpoint_path, + *, + map_location=None, + strict: bool = True): + try: + cls._set_model_restore_state(is_being_restored=True) + + from pytorch_lightning.utilities.cloud_io import load as pl_load + if map_location is not None: + checkpoint = pl_load(checkpoint_path, map_location=map_location) + else: + checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage) + + # for past checkpoint need to add the new key + assert cls.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint + pretrain_cfg = checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY] + + # give model a chance to load something + # model.on_load_checkpoint(checkpoint) + + # load the state_dict on the model automatically + if encoder_param_prefix is not None: + encoder_state = {k[len(encoder_param_prefix):]: v for k, v in checkpoint['state_dict'].items() if + k.startswith(encoder_param_prefix)} + else: + encoder_state = checkpoint['state_dict'] + encoder.load_state_dict(encoder_state, strict=strict) + finally: + cls._set_model_restore_state(is_being_restored=False) + + def multi_test_epoch_end(self, outputs, dataloader_idx: int = 0): + val_loss_mean = torch.stack([x['test_loss'] for x in outputs]).mean() + wer_num = torch.stack([x['test_wer_num'] for x in outputs]).sum() + wer_denom = torch.stack([x['test_wer_denom'] for x in outputs]).sum() + tensorboard_logs = {'test_loss': val_loss_mean, 'test_wer': wer_num / wer_denom} + references = itertools.chain.from_iterable([x['test_references'] for x in outputs]) + hypotheses = itertools.chain.from_iterable([x['test_hypotheses'] for x in outputs]) + test_logprob = [x['test_logprob'] for x in outputs] + test_logprob_len = [x['test_logprob_len'] for x in outputs] + test_logits = [x['test_logits'] for x in outputs] + return {'test_loss': val_loss_mean, 'log': tensorboard_logs, 'decode_results': (references, hypotheses), + 'test_logprob': test_logprob, 'test_logprob_len': test_logprob_len, 'test_logits': test_logits} diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/spec2vec/vq_ctc_finetune_backup.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/spec2vec/vq_ctc_finetune_backup.py new file mode 100644 index 0000000000000000000000000000000000000000..5a12b0c35846800140296597fbbf55cf3697aae9 --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/spec2vec/vq_ctc_finetune_backup.py @@ -0,0 +1,602 @@ +"""Modified from .ctc_finetune.py""" +import contextlib +import copy +import itertools +import json +import os +import tempfile +from math import ceil +from typing import Dict, List, Optional, Union + +import torch +from omegaconf import DictConfig, OmegaConf, open_dict, ListConfig +from pytorch_lightning import Trainer +from tqdm.auto import tqdm + +from nemo.collections.asr.data import audio_to_text_dataset +from nemo.collections.asr.losses.ctc import CTCLoss +from nemo.collections.asr.metrics.wer import WER, WER_phone +from nemo.collections.asr.metrics.wer_bpe import WERBPE +from nemo.collections.asr.models.asr_model import ASRModel +from nemo.collections.asr.parts.perturb import process_augmentations, RandomNoisePerturbation, AudioAugmentor +from nemo.utils import logging +from nemo.collections.asr.modules.wav2vec_modules import compute_mask_indices, GumbelVectorQuantizer + + +class VQCTCFinetuneModel(ASRModel): + """Todo: should modify""" + + def __init__(self, cfg: DictConfig, trainer: Trainer = None): + # Get global rank and total number of GPU workers for IterableDataset partitioning, if applicable + # Global_rank and local_rank is set by LightningModule in Lightning 1.2.0 + self.global_rank = 0 + self.world_size = 1 + self.local_rank = 0 + if trainer is not None: + self.global_rank = (trainer.node_rank * trainer.num_gpus) + trainer.local_rank + self.world_size = trainer.num_nodes * trainer.num_gpus + self.local_rank = trainer.local_rank + + self.label_type = cfg.label_type + assert self.label_type in ['char', 'phone', 'bpe'] + if self.label_type == 'bpe': + self.use_bpe = True + assert cfg.tokenizer is not None + from nemo.collections.asr.parts.mixins import ASRBPEMixin + self.bpe = ASRBPEMixin() + self.bpe._setup_tokenizer(cfg.tokenizer, register_artifact=False) + self.tokenizer = self.bpe.tokenizer + + # Initialize a dummy vocabulary + vocabulary = self.tokenizer.tokenizer.get_vocab() + + # Set the new vocabulary + assert len(cfg.decoder.vocabulary) == 0 + with open_dict(cfg): + cfg.decoder.vocabulary = ListConfig(list(vocabulary.values())) + else: + self.use_bpe = False + assert cfg.tokenizer is None + self.add_end_space = cfg.add_end_space + self.lang = cfg.lang + + super().__init__(cfg=cfg, trainer=trainer) + + if self._cfg.encoder_type == 'spec2vec': + from nemo.collections.asr.models.spec2vec.spec2vec_model import Spec2VecEncoder + self.encoder = Spec2VecEncoder(self._cfg.encoder) + encoder_param_prefix = 'spec2vec_encoder.' + elif self._cfg.encoder_type == 'st': + from nemo.collections.asr.models.st2vec.st2vec_model import ST2VecEncoder + self.encoder = ST2VecEncoder(self._cfg.encoder) + encoder_param_prefix = 'st2vec_encoder.' + elif self._cfg.encoder_type == 'feat_st': + from nemo.collections.asr.models.st2vec.st2vec_model import FeatST2VecEncoder + self.encoder = FeatST2VecEncoder(self._cfg.encoder) + encoder_param_prefix = 'st2vec_encoder.' + else: + assert self._cfg.encoder_type == 'wav2vec' + from nemo.collections.asr.modules.wav2vec_encoder import Wav2VecEncoderModel + self.encoder = Wav2VecEncoderModel(self._cfg.encoder) + encoder_param_prefix = None + if cfg.pretrain_chkpt_path is not None: + self.init_encoder_from_pretrain_model(self.encoder, encoder_param_prefix, cfg.pretrain_chkpt_path) + if self._cfg.encoder_type == 'st': + self.encoder.remove_pretraining_modules(use_teacher_encoder=self._cfg.use_teacher_encoder) + else: + self.encoder.remove_pretraining_modules() + + quantizer_input_dim = self.encoder.feature_encoder.output_dim + self.quantizer = GumbelVectorQuantizer( + dim=quantizer_input_dim, + num_vars=cfg.quantizer.latent_vars, + temp=cfg.quantizer.latent_temp, + groups=cfg.quantizer.latent_groups, + combine_groups=False, + vq_dim=cfg.quantizer.latent_dim, + time_first=True, + ) + + self.decoder = self.from_config_dict(self._cfg.decoder) + + self.freeze_finetune_updates = self._cfg.freeze_finetune_updates + + self.ctc_loss = CTCLoss( + blank_id=self.decoder.blank_idx, + zero_infinity=True, + reduction=self._cfg.get("ctc_reduction", "mean_batch"), + ) + + self.quant_ppl_loss_weight = cfg.quant_ppl_loss_weight + + # Setup metric objects + if self.use_bpe: + self._wer = WERBPE( + tokenizer=self.tokenizer, + blank_id=self.decoder.blank_idx, + batch_dim_index=0, + use_cer=self._cfg.get('use_cer', False), + ctc_decode=True, + dist_sync_on_step=True, + log_prediction=self._cfg.get("log_prediction", False), + lang=self.lang, + ) + else: + if self.label_type == 'phone': + WER_class = WER_phone + else: + WER_class = WER + self._wer = WER_class( + vocabulary=self.decoder.vocabulary, + blank_id=self.decoder.blank_idx, + batch_dim_index=0, + use_cer=self._cfg.get('use_cer', False), + ctc_decode=True, + dist_sync_on_step=True, + log_prediction=self._cfg.get("log_prediction", False), + strip_end_space=self.add_end_space, + ) + + @torch.no_grad() + def transcribe( + self, paths2audio_files: List[str], batch_size: int = 4, logprobs=False, return_hypotheses: bool = False + ) -> List[str]: + """ + Uses greedy decoding to transcribe audio files. Use this method for debugging and prototyping. + + Args: + paths2audio_files: (a list) of paths to audio files. \ + Recommended length per file is between 5 and 25 seconds. \ + But it is possible to pass a few hours long file if enough GPU memory is available. + batch_size: (int) batch size to use during inference. + Bigger will result in better throughput performance but would use more memory. + logprobs: (bool) pass True to get log probabilities instead of transcripts. + return_hypotheses: (bool) Either return hypotheses or text + With hypotheses can do some postprocessing like getting timestamp or rescoring + + Returns: + A list of transcriptions (or raw log probabilities if logprobs is True) in the same order as paths2audio_files + """ + if paths2audio_files is None or len(paths2audio_files) == 0: + return {} + + if return_hypotheses and logprobs: + raise ValueError( + "Either `return_hypotheses` or `logprobs` can be True at any given time." + "Returned hypotheses will contain the logprobs." + ) + + # We will store transcriptions here + hypotheses = [] + # Model's mode and device + mode = self.training + device = next(self.parameters()).device + # dither_value = self.preprocessor.featurizer.dither + # pad_to_value = self.preprocessor.featurizer.pad_to + dither_value = self.encoder.wav2spec.featurizer.dither + pad_to_value = self.encoder.wav2spec.featurizer.pad_to + + try: + # self.preprocessor.featurizer.dither = 0.0 + # self.preprocessor.featurizer.pad_to = 0 + self.encoder.wav2spec.featurizer.dither = 0.0 + self.encoder.wav2spec.featurizer.pad_to = 0 + # Switch model to evaluation mode + self.eval() + # Freeze the encoder and decoder modules + # Work in tmp directory - will store manifest file there + with tempfile.TemporaryDirectory() as tmpdir: + with open(os.path.join(tmpdir, 'manifest.json'), 'w') as fp: + for audio_file in paths2audio_files: + entry = {'audio_filepath': audio_file, 'duration': 100000, 'text': 'nothing'} + fp.write(json.dumps(entry) + '\n') + + config = {'paths2audio_files': paths2audio_files, 'batch_size': batch_size, 'temp_dir': tmpdir} + + temporary_datalayer = self._setup_transcribe_dataloader(config) + for test_batch in tqdm(temporary_datalayer, desc="Transcribing"): + # logits, logits_len, greedy_predictions = self( + # input_signal=test_batch[0].to(device), input_signal_length=test_batch[1].to(device) + # ) + + log_probs, logits_len, greedy_predictions, logits, prob_ppl_loss, cur_temp, prob_ppl = \ + self(input_signal=test_batch[0].to(device), input_signal_length=test_batch[1].to(device), + global_step=self.trainer.global_step) + + if logprobs: + # dump log probs per file + for idx in range(logits.shape[0]): + hypotheses.append(logits[idx][: logits_len[idx]]) + else: + current_hypotheses = self._wer.ctc_decoder_predictions_tensor( + greedy_predictions, predictions_len=logits_len, return_hypotheses=return_hypotheses, + ) + + if return_hypotheses: + # dump log probs per file + for idx in range(logits.shape[0]): + current_hypotheses[idx].y_sequence = logits[idx][: logits_len[idx]] + + hypotheses += current_hypotheses + + del greedy_predictions + del logits + del test_batch + finally: + # set mode back to its original value + self.train(mode=mode) + # self.preprocessor.featurizer.dither = dither_value + # self.preprocessor.featurizer.pad_to = pad_to_value + self.encoder.wav2spec.featurizer.dither = dither_value + self.encoder.wav2spec.featurizer.pad_to = pad_to_value + return hypotheses + + def change_vocabulary(self, new_vocabulary: List[str]): + """ + Changes vocabulary used during CTC decoding process. Use this method when fine-tuning on from pre-trained model. + This method changes only decoder and leaves encoder and pre-processing modules unchanged. For example, you would + use it if you want to use pretrained encoder when fine-tuning on a data in another language, or when you'd need + model to learn capitalization, punctuation and/or special characters. + + If new_vocabulary == self.decoder.vocabulary then nothing will be changed. + + Args: + + new_vocabulary: list with new vocabulary. Must contain at least 2 elements. Typically, \ + this is target alphabet. + + Returns: None + + """ + assert not self.use_bpe + if self.decoder.vocabulary == new_vocabulary: + logging.warning(f"Old {self.decoder.vocabulary} and new {new_vocabulary} match. Not changing anything.") + else: + if new_vocabulary is None or len(new_vocabulary) == 0: + raise ValueError(f'New vocabulary must be non-empty list of chars. But I got: {new_vocabulary}') + decoder_config = self.decoder.to_config_dict() + new_decoder_config = copy.deepcopy(decoder_config) + new_decoder_config['vocabulary'] = new_vocabulary + new_decoder_config['num_classes'] = len(new_vocabulary) + + del self.decoder + self.decoder = self.from_config_dict(new_decoder_config) + del self.ctc_loss + self.ctc_loss = CTCLoss( + num_classes=self.decoder.num_classes_with_blank - 1, + zero_infinity=True, + reduction=self._cfg.get("ctc_reduction", "mean_batch"), + ) + if self.label_type == 'phone': + WER_class = WER_phone + else: + WER_class = WER + self._wer = WER_class( + vocabulary=self.decoder.vocabulary, + batch_dim_index=0, + use_cer=self._cfg.get('use_cer', False), + ctc_decode=True, + dist_sync_on_step=True, + log_prediction=self._cfg.get("log_prediction", False), + ) + + # Update config + OmegaConf.set_struct(self._cfg.decoder, False) + self._cfg.decoder = new_decoder_config + OmegaConf.set_struct(self._cfg.decoder, True) + + logging.info(f"Changed decoder to output to {self.decoder.vocabulary} vocabulary.") + + def _setup_dataloader_from_config(self, config: Optional[Dict], noise_perturb_config): + if noise_perturb_config is not None: + noise_perturb = RandomNoisePerturbation(**noise_perturb_config) + augmentor = AudioAugmentor(perturbations=[(1.0, noise_perturb)]) + else: + augmentor = None + + shuffle = config['shuffle'] + + if 'manifest_filepath' in config and config['manifest_filepath'] is None: + logging.warning(f"Could not load dataset as `manifest_filepath` was None. Provided config : {config}") + return None + + if self.add_end_space: + config['parser_add_end_space'] = self.add_end_space + + if self.use_bpe: + dataset = audio_to_text_dataset.get_bpe_dataset(config=config, tokenizer=self.tokenizer, + augmentor=augmentor) + else: + if self.label_type == 'char': + dataset = audio_to_text_dataset.get_char_dataset(config=config, augmentor=augmentor) + elif self.label_type == 'phone': + dataset = audio_to_text_dataset.get_phone_dataset(config=config, augmentor=augmentor) + + return torch.utils.data.DataLoader( + dataset=dataset, + batch_size=config['batch_size'], + collate_fn=dataset.collate_fn, + drop_last=config.get('drop_last', False), + shuffle=shuffle, + num_workers=config.get('num_workers', 0), + pin_memory=config.get('pin_memory', False), + ) + + def setup_training_data(self, train_data_config: Optional[Union[DictConfig, Dict]]): + """ + Sets up the training data loader via a Dict-like object. + + Args: + train_data_config: A config that contains the information regarding construction + of an ASR Training dataset. + + Supported Datasets: + - :class:`~nemo.collections.asr.data.audio_to_text.AudioToCharDataset` + - :class:`~nemo.collections.asr.data.audio_to_text.AudioToBPEDataset` + - :class:`~nemo.collections.asr.data.audio_to_text.TarredAudioToCharDataset` + - :class:`~nemo.collections.asr.data.audio_to_text.TarredAudioToBPEDataset` + - :class:`~nemo.collections.asr.data.audio_to_text_dali.AudioToCharDALIDataset` + """ + if 'shuffle' not in train_data_config: + train_data_config['shuffle'] = True + + # preserve config + self._update_dataset_config(dataset_name='train', config=train_data_config) + + self._train_dl = self._setup_dataloader_from_config(config=train_data_config, + noise_perturb_config=self._cfg['noise_perturb']) + + # Need to set this because if using an IterableDataset, the length of the dataloader is the total number + # of samples rather than the number of batches, and this messes up the tqdm progress bar. + # So we set the number of steps manually (to the correct number) to fix this. + if 'is_tarred' in train_data_config and train_data_config['is_tarred']: + # We also need to check if limit_train_batches is already set. + # If it's an int, we assume that the user has set it to something sane, i.e. <= # training batches, + # and don't change it. Otherwise, adjust batches accordingly if it's a float (including 1.0). + if isinstance(self._trainer.limit_train_batches, float): + self._trainer.limit_train_batches = int( + self._trainer.limit_train_batches + * ceil((len(self._train_dl.dataset) / self.world_size) / train_data_config['batch_size']) + ) + + def setup_validation_data(self, val_data_config: Optional[Union[DictConfig, Dict]]): + """ + Sets up the validation data loader via a Dict-like object. + + Args: + val_data_config: A config that contains the information regarding construction + of an ASR Training dataset. + + Supported Datasets: + - :class:`~nemo.collections.asr.data.audio_to_text.AudioToCharDataset` + - :class:`~nemo.collections.asr.data.audio_to_text.AudioToBPEDataset` + - :class:`~nemo.collections.asr.data.audio_to_text.TarredAudioToCharDataset` + - :class:`~nemo.collections.asr.data.audio_to_text.TarredAudioToBPEDataset` + - :class:`~nemo.collections.asr.data.audio_to_text_dali.AudioToCharDALIDataset` + """ + if 'shuffle' not in val_data_config: + val_data_config['shuffle'] = False + + # preserve config + self._update_dataset_config(dataset_name='validation', config=val_data_config) + + self._validation_dl = self._setup_dataloader_from_config(config=val_data_config, noise_perturb_config=None) + + def setup_test_data(self, test_data_config: Optional[Union[DictConfig, Dict]]): + """ + Sets up the test data loader via a Dict-like object. + + Args: + test_data_config: A config that contains the information regarding construction + of an ASR Training dataset. + + Supported Datasets: + - :class:`~nemo.collections.asr.data.audio_to_text.AudioToCharDataset` + - :class:`~nemo.collections.asr.data.audio_to_text.AudioToBPEDataset` + - :class:`~nemo.collections.asr.data.audio_to_text.TarredAudioToCharDataset` + - :class:`~nemo.collections.asr.data.audio_to_text.TarredAudioToBPEDataset` + - :class:`~nemo.collections.asr.data.audio_to_text_dali.AudioToCharDALIDataset` + """ + if 'shuffle' not in test_data_config: + test_data_config['shuffle'] = False + + # preserve config + self._update_dataset_config(dataset_name='test', config=test_data_config) + + self._test_dl = self._setup_dataloader_from_config(config=test_data_config, noise_perturb_config=None) + + def optim_param_groups(self): + return [{'params': self.encoder.parameters(), 'weight_decay': 0.0}, + {'params': self.decoder.parameters()}] + + def forward(self, input_signal, input_signal_length, global_step): + """ + Args: + input_signal: Tensor that represents a batch of raw audio signals, + of shape [B, T]. T here represents timesteps, with 1 second of audio represented as + `self.sample_rate` number of floating point values. + input_signal_length: Vector of length B, that contains the individual lengths of the audio + sequences. + """ + ft = False if global_step is None else self.freeze_finetune_updates <= global_step + with torch.no_grad() if not ft else contextlib.suppress(): + encoded, encoded_len = self.encoder(input_signal, input_signal_length, None, None, + mask=self.training, features_only=True) + + # [B, T, D] => [B, D, T] + encoded = encoded.transpose(1, 2) + + # Ensure that shape mismatch does not occur due to padding + # Due to padding and subsequent downsampling, it may be possible that + # max sequence length computed does not match the actual max sequence length + max_output_len = encoded_len.max() + if encoded.shape[2] != max_output_len: + encoded = encoded.narrow(dim=2, start=0, length=max_output_len).contiguous() + + # [B, D, T] => [B, T, D] + encoded = encoded.transpose(1, 2) + self.quantizer.set_num_updates(global_step) + quantized_encoded, prob_ppl_loss, cur_temp, prob_ppl = self.quantizer(encoded) + # Todo: should exclude the units that are beyond the sequence length from quantization optimization + # [B, T, D] => [B, D, T] + quantized_encoded = quantized_encoded.transpose(1, 2) + + logits, encoded_len = self.decoder(encoder_output=quantized_encoded, lens=encoded_len, log_prob=False) + log_probs = torch.nn.functional.log_softmax(logits, dim=-1) + + with torch.no_grad(): + greedy_predictions = log_probs.argmax(dim=-1, keepdim=False) + + return log_probs, encoded_len, greedy_predictions, logits, prob_ppl_loss, cur_temp, prob_ppl + + def training_step(self, batch, batch_nb): + signal, signal_len, transcript, transcript_len = batch + + log_probs, encoded_len, predictions, _, prob_ppl_loss, cur_temp, prob_ppl = \ + self(input_signal=signal, input_signal_length=signal_len, global_step=self.trainer.global_step) + + ctc_loss = self.ctc_loss( + log_probs=log_probs, targets=transcript, input_lengths=encoded_len, target_lengths=transcript_len + ) + + if self.quant_ppl_loss_weight != 0: + weighted_prob_ppl_loss = self.quant_ppl_loss_weight * prob_ppl_loss + # Todo: check, this is mean value of batch + loss = ctc_loss + weighted_prob_ppl_loss + else: + loss = ctc_loss + + tensorboard_logs = {'train_loss': loss, 'ctc_loss': ctc_loss, 'weighted_prob_ppl_loss': weighted_prob_ppl_loss, + 'prob_ppl_loss': prob_ppl_loss, 'cur_temp': cur_temp, 'prob_ppl': prob_ppl, + 'learning_rate': self._optimizer.param_groups[0]['lr']} + + return {'loss': loss, 'log': tensorboard_logs} + + def validation_step(self, batch, batch_idx, dataloader_idx=0, decode_results=None): + signal, signal_len, transcript, transcript_len = batch + with torch.no_grad(): + log_probs, encoded_len, predictions, logits, prob_ppl_loss, cur_temp, prob_ppl = \ + self(input_signal=signal, input_signal_length=signal_len, global_step=self.trainer.global_step) + ctc_loss = self.ctc_loss( + log_probs=log_probs, targets=transcript, input_lengths=encoded_len, target_lengths=transcript_len + ) + + if self.quant_ppl_loss_weight != 0: + weighted_prob_ppl_loss = self.quant_ppl_loss_weight * prob_ppl_loss + # Todo: check, this is mean value of batch + loss = ctc_loss + weighted_prob_ppl_loss + else: + loss = ctc_loss + + self._wer.update( + predictions=predictions, targets=transcript, target_lengths=transcript_len, predictions_lengths=encoded_len, + log_prediction=batch_idx < 3, decode_results=decode_results) + wer, wer_num, wer_denom = self._wer.compute() + return { + 'val_loss': loss, + 'val_wer_num': wer_num, + 'val_wer_denom': wer_denom, + 'val_wer': wer, + 'val_logprob': log_probs.cpu().numpy(), + 'val_logprob_len': encoded_len.cpu().numpy(), + 'val_logits': logits.cpu().numpy(), + } + + def test_step(self, batch, batch_idx, dataloader_idx=0): + decode_results = {} + logs = self.validation_step(batch, batch_idx, dataloader_idx=dataloader_idx, decode_results=decode_results) + test_logs = { + 'test_loss': logs['val_loss'], + 'test_wer_num': logs['val_wer_num'], + 'test_wer_denom': logs['val_wer_denom'], + 'test_wer': logs['val_wer'], + 'test_references': decode_results['references'], + 'test_hypotheses': decode_results['hypotheses'], + 'test_logprob': logs['val_logprob'], + 'test_logprob_len': logs['val_logprob_len'], + 'test_logits': logs['val_logits'], + } + return test_logs + + def test_dataloader(self): + if self._test_dl is not None: + return self._test_dl + + def _setup_transcribe_dataloader(self, config: Dict) -> 'torch.utils.data.DataLoader': + """ + Setup function for a temporary data loader which wraps the provided audio file. + + Args: + config: A python dictionary which contains the following keys: + paths2audio_files: (a list) of paths to audio files. The files should be relatively short fragments. \ + Recommended length per file is between 5 and 25 seconds. + batch_size: (int) batch size to use during inference. \ + Bigger will result in better throughput performance but would use more memory. + temp_dir: (str) A temporary directory where the audio manifest is temporarily + stored. + + Returns: + A pytorch DataLoader for the given audio file(s). + """ + assert not self.use_bpe + dl_config = { + 'manifest_filepath': os.path.join(config['temp_dir'], 'manifest.json'), + # 'sample_rate': self.preprocessor._sample_rate, + 'sample_rate': self.encoder.wav2spec._sample_rate, + 'labels': self.decoder.vocabulary, + 'batch_size': min(config['batch_size'], len(config['paths2audio_files'])), + 'trim_silence': True, + 'shuffle': False, + } + + temporary_datalayer = self._setup_dataloader_from_config(config=DictConfig(dl_config), + noise_perturb_config=None) + return temporary_datalayer + + @classmethod + def init_encoder_from_pretrain_model( + cls, + encoder, + encoder_param_prefix, + checkpoint_path, + *, + map_location=None, + strict: bool = True): + try: + cls._set_model_restore_state(is_being_restored=True) + + from pytorch_lightning.utilities.cloud_io import load as pl_load + if map_location is not None: + checkpoint = pl_load(checkpoint_path, map_location=map_location) + else: + checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage) + + # for past checkpoint need to add the new key + assert cls.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint + pretrain_cfg = checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY] + + # give model a chance to load something + # model.on_load_checkpoint(checkpoint) + + # load the state_dict on the model automatically + if encoder_param_prefix is not None: + encoder_state = {k[len(encoder_param_prefix):]: v for k, v in checkpoint['state_dict'].items() if + k.startswith(encoder_param_prefix)} + else: + encoder_state = checkpoint['state_dict'] + encoder.load_state_dict(encoder_state, strict=strict) + finally: + cls._set_model_restore_state(is_being_restored=False) + + def multi_test_epoch_end(self, outputs, dataloader_idx: int = 0): + val_loss_mean = torch.stack([x['test_loss'] for x in outputs]).mean() + wer_num = torch.stack([x['test_wer_num'] for x in outputs]).sum() + wer_denom = torch.stack([x['test_wer_denom'] for x in outputs]).sum() + tensorboard_logs = {'test_loss': val_loss_mean, 'test_wer': wer_num / wer_denom} + references = itertools.chain.from_iterable([x['test_references'] for x in outputs]) + hypotheses = itertools.chain.from_iterable([x['test_hypotheses'] for x in outputs]) + test_logprob = [x['test_logprob'] for x in outputs] + test_logprob_len = [x['test_logprob_len'] for x in outputs] + test_logits = [x['test_logits'] for x in outputs] + return {'test_loss': val_loss_mean, 'log': tensorboard_logs, 'decode_results': (references, hypotheses), + 'test_logprob': test_logprob, 'test_logprob_len': test_logprob_len, 'test_logits': test_logits} diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/spec2vec/vq_ctc_finetune_bak_240103.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/spec2vec/vq_ctc_finetune_bak_240103.py new file mode 100644 index 0000000000000000000000000000000000000000..502418d9809d7a90d8b26110f1dcab04a10f21b0 --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/spec2vec/vq_ctc_finetune_bak_240103.py @@ -0,0 +1,798 @@ +"""Modified from .ctc_finetune.py""" +import contextlib +import copy +import itertools +import json +import os +import tempfile +from math import ceil +from typing import Dict, List, Optional, Union + +import torch +from omegaconf import DictConfig, OmegaConf, open_dict, ListConfig +from pytorch_lightning import Trainer +from tqdm.auto import tqdm + +from nemo.collections.asr.data import audio_to_text_dataset +from nemo.collections.asr.losses.ctc import CTCLoss +from nemo.collections.asr.metrics.wer import WER, WER_phone +from nemo.collections.asr.metrics.wer_bpe import WERBPE +from nemo.collections.asr.models.asr_model import ASRModel +from nemo.collections.asr.parts.perturb import process_augmentations, RandomNoisePerturbation, AudioAugmentor +from nemo.utils import logging +from nemo.collections.asr.modules.wav2vec_modules import compute_mask_indices, GumbelVectorQuantizer +from nemo.core.classes.modelPT import inspect, optim, hydra, prepare_lr_scheduler + +import torch.nn as nn +from my_scripts.fsq import FSQ + +class VQCTCFinetuneModel(ASRModel): + """Todo: should modify""" + + def __init__(self, cfg: DictConfig, trainer: Trainer = None): + # Get global rank and total number of GPU workers for IterableDataset partitioning, if applicable + # Global_rank and local_rank is set by LightningModule in Lightning 1.2.0 + self.global_rank = 0 + self.world_size = 1 + self.local_rank = 0 + if trainer is not None: + self.global_rank = (trainer.node_rank * trainer.num_gpus) + trainer.local_rank + self.world_size = trainer.num_nodes * trainer.num_gpus + self.local_rank = trainer.local_rank + + self.label_type = cfg.label_type + assert self.label_type in ['char', 'phone', 'bpe'] + if self.label_type == 'bpe': + self.use_bpe = True + assert cfg.tokenizer is not None + from nemo.collections.asr.parts.mixins import ASRBPEMixin + self.bpe = ASRBPEMixin() + self.bpe._setup_tokenizer(cfg.tokenizer, register_artifact=False) + self.tokenizer = self.bpe.tokenizer + + # Initialize a dummy vocabulary + vocabulary = self.tokenizer.tokenizer.get_vocab() + + # Set the new vocabulary + assert len(cfg.decoder.vocabulary) == 0 + with open_dict(cfg): + cfg.decoder.vocabulary = ListConfig(list(vocabulary.values())) + else: + self.use_bpe = False + assert cfg.tokenizer is None + self.add_end_space = cfg.add_end_space + self.lang = cfg.lang + + super().__init__(cfg=cfg, trainer=trainer) + + if self._cfg.encoder_type == 'spec2vec': + from nemo.collections.asr.models.spec2vec.spec2vec_model import Spec2VecEncoder + self.encoder = Spec2VecEncoder(self._cfg.encoder) + encoder_param_prefix = 'spec2vec_encoder.' + elif self._cfg.encoder_type == 'st': + from nemo.collections.asr.models.st2vec.st2vec_model import ST2VecEncoder + self.encoder = ST2VecEncoder(self._cfg.encoder) + encoder_param_prefix = 'st2vec_encoder.' + elif self._cfg.encoder_type == 'feat_st': + from nemo.collections.asr.models.st2vec.st2vec_model import FeatST2VecEncoder + self.encoder = FeatST2VecEncoder(self._cfg.encoder) + encoder_param_prefix = 'st2vec_encoder.' + else: + assert self._cfg.encoder_type == 'wav2vec' + from nemo.collections.asr.modules.wav2vec_encoder import Wav2VecEncoderModel + self.encoder = Wav2VecEncoderModel(self._cfg.encoder) + encoder_param_prefix = None + if cfg.pretrain_chkpt_path is not None: + self.init_encoder_from_pretrain_model(self.encoder, encoder_param_prefix, cfg.pretrain_chkpt_path) + if self._cfg.encoder_type == 'st': + self.encoder.remove_pretraining_modules(use_teacher_encoder=self._cfg.use_teacher_encoder) + else: + self.encoder.remove_pretraining_modules() + + # # quantizer_input_dim = self.encoder.feature_encoder.output_dim + # self.pre_logits_bottleneck_layer = nn.Linear(self.encoder.feature_encoder.output_dim, cfg.quantizer.pre_logits_dim) + # + # self.quantizer = GumbelVectorQuantizer( + # dim=cfg.quantizer.pre_logits_dim, # quantizer_input_dim, + # num_vars=cfg.quantizer.latent_vars, + # temp=cfg.quantizer.latent_temp, + # groups=cfg.quantizer.latent_groups, + # combine_groups=False, + # vq_dim=cfg.quantizer.latent_dim, + # time_first=True, + # ) + + levels = cfg.quantizer.levels + self.pre_quant = nn.Linear(self.encoder.feature_encoder.output_dim, len(levels)) + self.quantizer = FSQ(levels) + + self.decoder = self.from_config_dict(self._cfg.decoder) + + self.freeze_finetune_updates = self._cfg.freeze_finetune_updates + + self.ctc_loss = CTCLoss( + blank_id=self.decoder.blank_idx, + zero_infinity=True, + reduction=self._cfg.get("ctc_reduction", "mean_batch"), + ) + + self.quant_ppl_loss_weight = cfg.quant_ppl_loss_weight + + # Setup metric objects + if self.use_bpe: + self._wer = WERBPE( + tokenizer=self.tokenizer, + blank_id=self.decoder.blank_idx, + batch_dim_index=0, + use_cer=self._cfg.get('use_cer', False), + ctc_decode=True, + dist_sync_on_step=True, + log_prediction=self._cfg.get("log_prediction", False), + lang=self.lang, + ) + else: + if self.label_type == 'phone': + WER_class = WER_phone + else: + WER_class = WER + self._wer = WER_class( + vocabulary=self.decoder.vocabulary, + blank_id=self.decoder.blank_idx, + batch_dim_index=0, + use_cer=self._cfg.get('use_cer', False), + ctc_decode=True, + dist_sync_on_step=True, + log_prediction=self._cfg.get("log_prediction", False), + strip_end_space=self.add_end_space, + ) + + @torch.no_grad() + def transcribe( + self, paths2audio_files: List[str], batch_size: int = 4, logprobs=False, return_hypotheses: bool = False + ) -> List[str]: + """ + Uses greedy decoding to transcribe audio files. Use this method for debugging and prototyping. + + Args: + paths2audio_files: (a list) of paths to audio files. \ + Recommended length per file is between 5 and 25 seconds. \ + But it is possible to pass a few hours long file if enough GPU memory is available. + batch_size: (int) batch size to use during inference. + Bigger will result in better throughput performance but would use more memory. + logprobs: (bool) pass True to get log probabilities instead of transcripts. + return_hypotheses: (bool) Either return hypotheses or text + With hypotheses can do some postprocessing like getting timestamp or rescoring + + Returns: + A list of transcriptions (or raw log probabilities if logprobs is True) in the same order as paths2audio_files + """ + if paths2audio_files is None or len(paths2audio_files) == 0: + return {} + + if return_hypotheses and logprobs: + raise ValueError( + "Either `return_hypotheses` or `logprobs` can be True at any given time." + "Returned hypotheses will contain the logprobs." + ) + + # We will store transcriptions here + hypotheses = [] + # Model's mode and device + mode = self.training + device = next(self.parameters()).device + # dither_value = self.preprocessor.featurizer.dither + # pad_to_value = self.preprocessor.featurizer.pad_to + dither_value = self.encoder.wav2spec.featurizer.dither + pad_to_value = self.encoder.wav2spec.featurizer.pad_to + + try: + # self.preprocessor.featurizer.dither = 0.0 + # self.preprocessor.featurizer.pad_to = 0 + self.encoder.wav2spec.featurizer.dither = 0.0 + self.encoder.wav2spec.featurizer.pad_to = 0 + # Switch model to evaluation mode + self.eval() + # Freeze the encoder and decoder modules + # Work in tmp directory - will store manifest file there + with tempfile.TemporaryDirectory() as tmpdir: + with open(os.path.join(tmpdir, 'manifest.json'), 'w') as fp: + for audio_file in paths2audio_files: + entry = {'audio_filepath': audio_file, 'duration': 100000, 'text': 'nothing'} + fp.write(json.dumps(entry) + '\n') + + config = {'paths2audio_files': paths2audio_files, 'batch_size': batch_size, 'temp_dir': tmpdir} + + temporary_datalayer = self._setup_transcribe_dataloader(config) + for test_batch in tqdm(temporary_datalayer, desc="Transcribing"): + # logits, logits_len, greedy_predictions = self( + # input_signal=test_batch[0].to(device), input_signal_length=test_batch[1].to(device) + # ) + + # log_probs, logits_len, greedy_predictions, logits, prob_ppl_loss, cur_temp, prob_ppl = \ + # self(input_signal=test_batch[0].to(device), input_signal_length=test_batch[1].to(device), + # global_step=self.trainer.global_step) + + log_probs, logits_len, greedy_predictions, logits = \ + self(input_signal=test_batch[0].to(device), input_signal_length=test_batch[1].to(device), + global_step=self.trainer.global_step) + + if logprobs: + # dump log probs per file + for idx in range(logits.shape[0]): + hypotheses.append(logits[idx][: logits_len[idx]]) + else: + current_hypotheses = self._wer.ctc_decoder_predictions_tensor( + greedy_predictions, predictions_len=logits_len, return_hypotheses=return_hypotheses, + ) + + if return_hypotheses: + # dump log probs per file + for idx in range(logits.shape[0]): + current_hypotheses[idx].y_sequence = logits[idx][: logits_len[idx]] + + hypotheses += current_hypotheses + + del greedy_predictions + del logits + del test_batch + finally: + # set mode back to its original value + self.train(mode=mode) + # self.preprocessor.featurizer.dither = dither_value + # self.preprocessor.featurizer.pad_to = pad_to_value + self.encoder.wav2spec.featurizer.dither = dither_value + self.encoder.wav2spec.featurizer.pad_to = pad_to_value + return hypotheses + + def change_vocabulary(self, new_vocabulary: List[str]): + """ + Changes vocabulary used during CTC decoding process. Use this method when fine-tuning on from pre-trained model. + This method changes only decoder and leaves encoder and pre-processing modules unchanged. For example, you would + use it if you want to use pretrained encoder when fine-tuning on a data in another language, or when you'd need + model to learn capitalization, punctuation and/or special characters. + + If new_vocabulary == self.decoder.vocabulary then nothing will be changed. + + Args: + + new_vocabulary: list with new vocabulary. Must contain at least 2 elements. Typically, \ + this is target alphabet. + + Returns: None + + """ + assert not self.use_bpe + if self.decoder.vocabulary == new_vocabulary: + logging.warning(f"Old {self.decoder.vocabulary} and new {new_vocabulary} match. Not changing anything.") + else: + if new_vocabulary is None or len(new_vocabulary) == 0: + raise ValueError(f'New vocabulary must be non-empty list of chars. But I got: {new_vocabulary}') + decoder_config = self.decoder.to_config_dict() + new_decoder_config = copy.deepcopy(decoder_config) + new_decoder_config['vocabulary'] = new_vocabulary + new_decoder_config['num_classes'] = len(new_vocabulary) + + del self.decoder + self.decoder = self.from_config_dict(new_decoder_config) + del self.ctc_loss + self.ctc_loss = CTCLoss( + num_classes=self.decoder.num_classes_with_blank - 1, + zero_infinity=True, + reduction=self._cfg.get("ctc_reduction", "mean_batch"), + ) + if self.label_type == 'phone': + WER_class = WER_phone + else: + WER_class = WER + self._wer = WER_class( + vocabulary=self.decoder.vocabulary, + batch_dim_index=0, + use_cer=self._cfg.get('use_cer', False), + ctc_decode=True, + dist_sync_on_step=True, + log_prediction=self._cfg.get("log_prediction", False), + ) + + # Update config + OmegaConf.set_struct(self._cfg.decoder, False) + self._cfg.decoder = new_decoder_config + OmegaConf.set_struct(self._cfg.decoder, True) + + logging.info(f"Changed decoder to output to {self.decoder.vocabulary} vocabulary.") + + def _setup_dataloader_from_config(self, config: Optional[Dict], noise_perturb_config): + if noise_perturb_config is not None: + noise_perturb = RandomNoisePerturbation(**noise_perturb_config) + augmentor = AudioAugmentor(perturbations=[(1.0, noise_perturb)]) + else: + augmentor = None + + shuffle = config['shuffle'] + + if 'manifest_filepath' in config and config['manifest_filepath'] is None: + logging.warning(f"Could not load dataset as `manifest_filepath` was None. Provided config : {config}") + return None + + if self.add_end_space: + config['parser_add_end_space'] = self.add_end_space + + if self.use_bpe: + dataset = audio_to_text_dataset.get_bpe_dataset(config=config, tokenizer=self.tokenizer, + augmentor=augmentor) + else: + if self.label_type == 'char': + dataset = audio_to_text_dataset.get_char_dataset(config=config, augmentor=augmentor) + elif self.label_type == 'phone': + dataset = audio_to_text_dataset.get_phone_dataset(config=config, augmentor=augmentor) + + return torch.utils.data.DataLoader( + dataset=dataset, + batch_size=config['batch_size'], + collate_fn=dataset.collate_fn, + drop_last=config.get('drop_last', False), + shuffle=shuffle, + num_workers=config.get('num_workers', 0), + pin_memory=config.get('pin_memory', False), + ) + + def setup_training_data(self, train_data_config: Optional[Union[DictConfig, Dict]]): + """ + Sets up the training data loader via a Dict-like object. + + Args: + train_data_config: A config that contains the information regarding construction + of an ASR Training dataset. + + Supported Datasets: + - :class:`~nemo.collections.asr.data.audio_to_text.AudioToCharDataset` + - :class:`~nemo.collections.asr.data.audio_to_text.AudioToBPEDataset` + - :class:`~nemo.collections.asr.data.audio_to_text.TarredAudioToCharDataset` + - :class:`~nemo.collections.asr.data.audio_to_text.TarredAudioToBPEDataset` + - :class:`~nemo.collections.asr.data.audio_to_text_dali.AudioToCharDALIDataset` + """ + if 'shuffle' not in train_data_config: + train_data_config['shuffle'] = True + + # preserve config + self._update_dataset_config(dataset_name='train', config=train_data_config) + + self._train_dl = self._setup_dataloader_from_config(config=train_data_config, + noise_perturb_config=self._cfg['noise_perturb']) + + # Need to set this because if using an IterableDataset, the length of the dataloader is the total number + # of samples rather than the number of batches, and this messes up the tqdm progress bar. + # So we set the number of steps manually (to the correct number) to fix this. + if 'is_tarred' in train_data_config and train_data_config['is_tarred']: + # We also need to check if limit_train_batches is already set. + # If it's an int, we assume that the user has set it to something sane, i.e. <= # training batches, + # and don't change it. Otherwise, adjust batches accordingly if it's a float (including 1.0). + if isinstance(self._trainer.limit_train_batches, float): + self._trainer.limit_train_batches = int( + self._trainer.limit_train_batches + * ceil((len(self._train_dl.dataset) / self.world_size) / train_data_config['batch_size']) + ) + + def setup_validation_data(self, val_data_config: Optional[Union[DictConfig, Dict]]): + """ + Sets up the validation data loader via a Dict-like object. + + Args: + val_data_config: A config that contains the information regarding construction + of an ASR Training dataset. + + Supported Datasets: + - :class:`~nemo.collections.asr.data.audio_to_text.AudioToCharDataset` + - :class:`~nemo.collections.asr.data.audio_to_text.AudioToBPEDataset` + - :class:`~nemo.collections.asr.data.audio_to_text.TarredAudioToCharDataset` + - :class:`~nemo.collections.asr.data.audio_to_text.TarredAudioToBPEDataset` + - :class:`~nemo.collections.asr.data.audio_to_text_dali.AudioToCharDALIDataset` + """ + if 'shuffle' not in val_data_config: + val_data_config['shuffle'] = False + + # preserve config + self._update_dataset_config(dataset_name='validation', config=val_data_config) + + self._validation_dl = self._setup_dataloader_from_config(config=val_data_config, noise_perturb_config=None) + + def setup_test_data(self, test_data_config: Optional[Union[DictConfig, Dict]]): + """ + Sets up the test data loader via a Dict-like object. + + Args: + test_data_config: A config that contains the information regarding construction + of an ASR Training dataset. + + Supported Datasets: + - :class:`~nemo.collections.asr.data.audio_to_text.AudioToCharDataset` + - :class:`~nemo.collections.asr.data.audio_to_text.AudioToBPEDataset` + - :class:`~nemo.collections.asr.data.audio_to_text.TarredAudioToCharDataset` + - :class:`~nemo.collections.asr.data.audio_to_text.TarredAudioToBPEDataset` + - :class:`~nemo.collections.asr.data.audio_to_text_dali.AudioToCharDALIDataset` + """ + if 'shuffle' not in test_data_config: + test_data_config['shuffle'] = False + + # preserve config + self._update_dataset_config(dataset_name='test', config=test_data_config) + + self._test_dl = self._setup_dataloader_from_config(config=test_data_config, noise_perturb_config=None) + + # def optim_param_groups(self): + # return [{'params': self.encoder.parameters(), 'weight_decay': 0.0}, + # {'params': self.decoder.parameters()}] + + def setup_optimization(self, optim_config: Optional[Union[DictConfig, Dict]] = None): + """ + Prepares an optimizer from a string name and its optional config parameters. + + Args: + optim_config: A dictionary containing the following keys: + + * "lr": mandatory key for learning rate. Will raise ValueError if not provided. + * "optimizer": string name pointing to one of the available optimizers in the registry. \ + If not provided, defaults to "adam". + * "opt_args": Optional list of strings, in the format "arg_name=arg_value". \ + The list of "arg_value" will be parsed and a dictionary of optimizer kwargs \ + will be built and supplied to instantiate the optimizer. + """ + ## modifed by me!! + # If config was not explicitly passed to us + if optim_config is None: + # See if internal config has `optim` namespace + if self._cfg is not None and hasattr(self._cfg, 'optim'): + optim_config = self._cfg.optim + + # If config is still None, or internal config has no Optim, return without instantiation + if optim_config is None: + logging.info('No optimizer config provided, therefore no optimizer was created') + return + + else: + # Preserve the configuration + if not isinstance(optim_config, DictConfig): + optim_config = OmegaConf.create(optim_config) + + # See if internal config has `optim` namespace before preservation + if self._cfg is not None and hasattr(self._cfg, 'optim'): + if self._cfg.optim is None: + self._cfg.optim = copy.deepcopy(optim_config) + else: + with open_dict(self._cfg.optim): + self._cfg.optim = copy.deepcopy(optim_config) + + # Setup optimizer and scheduler + if optim_config is not None and isinstance(optim_config, DictConfig): + optim_config = OmegaConf.to_container(optim_config, resolve=True) + + if 'sched' in optim_config and optim_config[ + 'sched'] is not None and self._trainer is not None: ## this line modified! + if not isinstance(self._trainer.accumulate_grad_batches, int): + raise ValueError("We do not currently support gradient acculumation that is not an integer.") + if self._trainer.max_steps is None: + # Store information needed to calculate max_steps + optim_config['sched']['t_max_epochs'] = self._trainer.max_epochs + optim_config['sched']['t_accumulate_grad_batches'] = self._trainer.accumulate_grad_batches + optim_config['sched']['t_limit_train_batches'] = self._trainer.limit_train_batches + if self._trainer.distributed_backend is None: + optim_config['sched']['t_num_workers'] = self._trainer.num_gpus or 1 + elif self._trainer.distributed_backend == "ddp_cpu": + optim_config['sched']['t_num_workers'] = self._trainer.num_processes * self._trainer.num_nodes + elif self._trainer.distributed_backend == "ddp": + optim_config['sched']['t_num_workers'] = self._trainer.num_gpus * self._trainer.num_nodes + else: + logging.warning( + f"The lightning trainer received accelerator: {self._trainer.distributed_backend}. We " + "recommend to use 'ddp' instead." + ) + optim_config['sched']['t_num_workers'] = self._trainer.num_gpus * self._trainer.num_nodes + else: + optim_config['sched']['max_steps'] = self._trainer.max_steps + + # Force into DictConfig from nested structure + optim_config = OmegaConf.create(optim_config) + # Get back nested dict so we its mutable + optim_config = OmegaConf.to_container(optim_config, resolve=True) + + # Extract scheduler config if inside optimizer config + if 'sched' in optim_config: + scheduler_config = optim_config.pop('sched') + else: + scheduler_config = None + + # Check if caller provided optimizer name, default to Adam otherwise + optimizer_cls = optim_config.get('_target_', None) + + if optimizer_cls is None: + # Try to get optimizer name for dynamic resolution, defaulting to Adam + optimizer_name = optim_config.get('name', 'adam') + else: + if inspect.isclass(optimizer_cls): + optimizer_name = optimizer_cls.__name__.lower() + else: + # resolve the class name (lowercase) from the class path if not provided + optimizer_name = optimizer_cls.split(".")[-1].lower() + + # We are guarenteed to have lr since it is required by the argparser + # But maybe user forgot to pass it to this function + lr = optim_config.get('lr', None) + + # Check if caller has optimizer kwargs, default to empty dictionary + if 'args' in optim_config: + optimizer_args = optim_config.pop('args') + optimizer_args = optim.parse_optimizer_args(optimizer_name, optimizer_args) + else: + optimizer_args = copy.deepcopy(optim_config) + + # Remove extra parameters from optimizer_args nest + # Assume all other parameters are to be passed into optimizer constructor + optimizer_args.pop('name', None) + optimizer_args.pop('cls', None) + optimizer_args.pop('lr', None) + + # Adaptive schedulers don't need `lr` + if lr is not None: + optimizer_args['lr'] = lr + + # Actually instantiate the optimizer + if optimizer_cls is not None: + if inspect.isclass(optimizer_cls): + optimizer = optimizer_cls(self.optim_param_groups(), **optimizer_args) + logging.info("Optimizer config = %s", str(optimizer)) + + self._optimizer = optimizer + + else: + # Attempt class path resolution + try: + optimizer_cls = OmegaConf.create({'_target_': optimizer_cls}) + if lr is not None: + optimizer_config = {'lr': lr} + else: + optimizer_config = {} + optimizer_config.update(optimizer_args) + + optimizer_instance = hydra.utils.instantiate( + optimizer_cls, self.optim_param_groups(), **optimizer_config + ) # type: DictConfig + + logging.info("Optimizer config = %s", str(optimizer_instance)) + + self._optimizer = optimizer_instance + + except Exception as e: + logging.error( + "Could not instantiate class path - {} with kwargs {}".format( + optimizer_cls, str(optimizer_config) + ) + ) + raise e + + else: + optimizer = optim.get_optimizer(optimizer_name) + optimizer = optimizer(self.optim_param_groups(), **optimizer_args) + + logging.info("Optimizer config = %s", str(optimizer)) + + self._optimizer = optimizer + + # Try to instantiate scheduler for optimizer + self._scheduler = prepare_lr_scheduler( + optimizer=self._optimizer, scheduler_config=scheduler_config, train_dataloader=self._train_dl + ) + + # Return the optimizer with/without scheduler + # This return allows multiple optimizers or schedulers to be created + return self._optimizer, self._scheduler + + def forward(self, input_signal, input_signal_length, global_step): + """ + Args: + input_signal: Tensor that represents a batch of raw audio signals, + of shape [B, T]. T here represents timesteps, with 1 second of audio represented as + `self.sample_rate` number of floating point values. + input_signal_length: Vector of length B, that contains the individual lengths of the audio + sequences. + """ + ft = False if global_step is None else self.freeze_finetune_updates <= global_step + with torch.no_grad() if not ft else contextlib.suppress(): + encoded, encoded_len = self.encoder(input_signal, input_signal_length, None, None, + mask=self.training, features_only=True) + + # [B, T, D] => [B, D, T] + encoded = encoded.transpose(1, 2) + + # Ensure that shape mismatch does not occur due to padding + # Due to padding and subsequent downsampling, it may be possible that + # max sequence length computed does not match the actual max sequence length + max_output_len = encoded_len.max() + if encoded.shape[2] != max_output_len: + encoded = encoded.narrow(dim=2, start=0, length=max_output_len).contiguous() + + # [B, D, T] => [B, T, D] + encoded = encoded.transpose(1, 2) + + # self.quantizer.set_num_updates(global_step) + # + # encoded = self.pre_logits_bottleneck_layer(encoded) + # quantized_encoded, prob_ppl_loss, cur_temp, prob_ppl = self.quantizer(encoded) + + encoded = self.pre_quant(encoded) + quantized_encoded = self.quantizer.quantize(encoded) + + # Todo: should exclude the units that are beyond the sequence length from quantization optimization + # [B, T, D] => [B, D, T] + quantized_encoded = quantized_encoded.transpose(1, 2) + + logits, encoded_len = self.decoder(encoder_output=quantized_encoded, lens=encoded_len, log_prob=False) + log_probs = torch.nn.functional.log_softmax(logits, dim=-1) + + with torch.no_grad(): + greedy_predictions = log_probs.argmax(dim=-1, keepdim=False) + + # return log_probs, encoded_len, greedy_predictions, logits, prob_ppl_loss, cur_temp, prob_ppl + return log_probs, encoded_len, greedy_predictions, logits + + def training_step(self, batch, batch_nb): + signal, signal_len, transcript, transcript_len = batch + + # log_probs, encoded_len, predictions, _, prob_ppl_loss, cur_temp, prob_ppl = \ + # self(input_signal=signal, input_signal_length=signal_len, global_step=self.trainer.global_step) + + log_probs, encoded_len, predictions, _ = \ + self(input_signal=signal, input_signal_length=signal_len, global_step=self.trainer.global_step) + + ctc_loss = self.ctc_loss( + log_probs=log_probs, targets=transcript, input_lengths=encoded_len, target_lengths=transcript_len + ) + + # if self.quant_ppl_loss_weight != 0: + # weighted_prob_ppl_loss = self.quant_ppl_loss_weight * prob_ppl_loss + # # Todo: check, this is mean value of batch + # loss = ctc_loss + weighted_prob_ppl_loss + # else: + # loss = ctc_loss + loss = ctc_loss + + # tensorboard_logs = {'train_loss': loss, 'ctc_loss': ctc_loss, 'weighted_prob_ppl_loss': weighted_prob_ppl_loss, + # 'prob_ppl_loss': prob_ppl_loss, 'cur_temp': cur_temp, 'prob_ppl': prob_ppl, + # 'learning_rate': self._optimizer.param_groups[0]['lr']} + tensorboard_logs = {'train_loss': loss, 'ctc_loss': ctc_loss, 'learning_rate': self._optimizer.param_groups[0]['lr']} + + return {'loss': loss, 'log': tensorboard_logs} + + def validation_step(self, batch, batch_idx, dataloader_idx=0, decode_results=None): + signal, signal_len, transcript, transcript_len = batch + + with torch.no_grad(): + # log_probs, encoded_len, predictions, logits, prob_ppl_loss, cur_temp, prob_ppl = \ + # self(input_signal=signal, input_signal_length=signal_len, global_step=self.trainer.global_step) + + log_probs, encoded_len, predictions, logits, = \ + self(input_signal=signal, input_signal_length=signal_len, global_step=self.trainer.global_step) + + ctc_loss = self.ctc_loss( + log_probs=log_probs, targets=transcript, input_lengths=encoded_len, target_lengths=transcript_len + ) + + # if self.quant_ppl_loss_weight != 0: + # weighted_prob_ppl_loss = self.quant_ppl_loss_weight * prob_ppl_loss + # # Todo: check, this is mean value of batch + # loss = ctc_loss + weighted_prob_ppl_loss + # else: + # loss = ctc_loss + loss = ctc_loss + + self._wer.update( + predictions=predictions, targets=transcript, target_lengths=transcript_len, predictions_lengths=encoded_len, + log_prediction=batch_idx < 3, decode_results=decode_results) + wer, wer_num, wer_denom = self._wer.compute() + + return { + 'val_loss': loss, + 'val_wer_num': wer_num, + 'val_wer_denom': wer_denom, + 'val_wer': wer, + 'val_logprob': log_probs.cpu().numpy(), + 'val_logprob_len': encoded_len.cpu().numpy(), + 'val_logits': logits.cpu().numpy(), + } + + def test_step(self, batch, batch_idx, dataloader_idx=0): + decode_results = {} + logs = self.validation_step(batch, batch_idx, dataloader_idx=dataloader_idx, decode_results=decode_results) + test_logs = { + 'test_loss': logs['val_loss'], + 'test_wer_num': logs['val_wer_num'], + 'test_wer_denom': logs['val_wer_denom'], + 'test_wer': logs['val_wer'], + 'test_references': decode_results['references'], + 'test_hypotheses': decode_results['hypotheses'], + 'test_logprob': logs['val_logprob'], + 'test_logprob_len': logs['val_logprob_len'], + 'test_logits': logs['val_logits'], + } + return test_logs + + def test_dataloader(self): + if self._test_dl is not None: + return self._test_dl + + def _setup_transcribe_dataloader(self, config: Dict) -> 'torch.utils.data.DataLoader': + """ + Setup function for a temporary data loader which wraps the provided audio file. + + Args: + config: A python dictionary which contains the following keys: + paths2audio_files: (a list) of paths to audio files. The files should be relatively short fragments. \ + Recommended length per file is between 5 and 25 seconds. + batch_size: (int) batch size to use during inference. \ + Bigger will result in better throughput performance but would use more memory. + temp_dir: (str) A temporary directory where the audio manifest is temporarily + stored. + + Returns: + A pytorch DataLoader for the given audio file(s). + """ + assert not self.use_bpe + dl_config = { + 'manifest_filepath': os.path.join(config['temp_dir'], 'manifest.json'), + # 'sample_rate': self.preprocessor._sample_rate, + 'sample_rate': self.encoder.wav2spec._sample_rate, + 'labels': self.decoder.vocabulary, + 'batch_size': min(config['batch_size'], len(config['paths2audio_files'])), + 'trim_silence': True, + 'shuffle': False, + } + + temporary_datalayer = self._setup_dataloader_from_config(config=DictConfig(dl_config), + noise_perturb_config=None) + return temporary_datalayer + + @classmethod + def init_encoder_from_pretrain_model( + cls, + encoder, + encoder_param_prefix, + checkpoint_path, + *, + map_location=None, + strict: bool = True): + try: + cls._set_model_restore_state(is_being_restored=True) + + from pytorch_lightning.utilities.cloud_io import load as pl_load + if map_location is not None: + checkpoint = pl_load(checkpoint_path, map_location=map_location) + else: + checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage) + + # for past checkpoint need to add the new key + assert cls.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint + pretrain_cfg = checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY] + + # give model a chance to load something + # model.on_load_checkpoint(checkpoint) + + # load the state_dict on the model automatically + if encoder_param_prefix is not None: + encoder_state = {k[len(encoder_param_prefix):]: v for k, v in checkpoint['state_dict'].items() if + k.startswith(encoder_param_prefix)} + else: + encoder_state = checkpoint['state_dict'] + encoder.load_state_dict(encoder_state, strict=strict) + finally: + cls._set_model_restore_state(is_being_restored=False) + + def multi_test_epoch_end(self, outputs, dataloader_idx: int = 0): + val_loss_mean = torch.stack([x['test_loss'] for x in outputs]).mean() + wer_num = torch.stack([x['test_wer_num'] for x in outputs]).sum() + wer_denom = torch.stack([x['test_wer_denom'] for x in outputs]).sum() + tensorboard_logs = {'test_loss': val_loss_mean, 'test_wer': wer_num / wer_denom} + references = itertools.chain.from_iterable([x['test_references'] for x in outputs]) + hypotheses = itertools.chain.from_iterable([x['test_hypotheses'] for x in outputs]) + test_logprob = [x['test_logprob'] for x in outputs] + test_logprob_len = [x['test_logprob_len'] for x in outputs] + test_logits = [x['test_logits'] for x in outputs] + return {'test_loss': val_loss_mean, 'log': tensorboard_logs, 'decode_results': (references, hypotheses), + 'test_logprob': test_logprob, 'test_logprob_len': test_logprob_len, 'test_logits': test_logits} diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/spec2vec/vq_ctc_recon_finetune.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/spec2vec/vq_ctc_recon_finetune.py new file mode 100644 index 0000000000000000000000000000000000000000..132857aebbf06e21f4af484ed060b966b2921e9e --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/spec2vec/vq_ctc_recon_finetune.py @@ -0,0 +1,865 @@ +import contextlib +import copy +import itertools +import json +import os +import tempfile +from math import ceil +from typing import Dict, List, Optional, Union, Tuple +from dataclasses import field, dataclass +import torch +from torch import nn +from omegaconf import DictConfig, OmegaConf, open_dict, ListConfig, MISSING +from hydra.utils import instantiate +from pytorch_lightning import Trainer +from tqdm.auto import tqdm +from nemo.collections.asr.data import audio_to_text_dataset +from nemo.collections.asr.losses.ctc import CTCLoss +from nemo.collections.asr.metrics.wer import WER, WER_phone +from nemo.collections.asr.metrics.wer_bpe import WERBPE +from nemo.collections.asr.models.asr_model import ASRModel +from nemo.collections.asr.parts.perturb import RandomNoisePerturbation, AudioAugmentor +from nemo.utils import logging +from nemo.core.classes.modelPT import inspect, optim, hydra, prepare_lr_scheduler +from nemo.collections.asr.modules.wav2vec_modules import GumbelVectorQuantizer +from nemo.collections.asr.parts.spec2vec import Projector +from nemo.collections.asr.models.spec2vec.spec2vec_config import ProjectorConfig +from nemo.collections.asr.models.st2vec.st2vec_model import create_padding_mask +from nemo.collections.asr.models.spec2vec.spec2vec_config import ST2VecVQCTCFinetuneModelConfig + + +@dataclass +class ReconstructorConfig: + generator: ProjectorConfig = MISSING + global_cond_extractor: Optional[ProjectorConfig] = None + reduction_rate: int = MISSING + loss_type: str = 'l2' + + +@dataclass +class ST2VecVQCTCReconFinetuneModelConfig(ST2VecVQCTCFinetuneModelConfig): + reconstruction_loss_weight: float = MISSING + reconstructor: ReconstructorConfig = MISSING + + +class Reconstrutor(nn.Module): + def __init__(self, cfg: DictConfig): + super().__init__() + + self.generator = Projector(cfg.generator) + self.reduction_rate = cfg.reduction_rate + self.loss_type = cfg.loss_type + + if cfg.global_cond_extractor is not None: + self.global_cond_extractor = Projector(cfg.global_cond_extractor) + else: + self.global_cond_extractor = None + + def forward(self, h, h_len, target, target_len): + # [B, D, T] => [B, T, D] + target = target.transpose(1, 2) + + if self.global_cond_extractor is not None: + global_cond = self.get_global_cond(target, target_len) + else: + global_cond = None + + if self.global_cond_extractor is not None: + h = h + global_cond.unsqueeze(1) + + pred = self.generator(h, length=h_len) + + B, T, C = pred.size() + assert C // self.reduction_rate == target.shape[2] + pred = pred.reshape(B, T * self.reduction_rate, C // self.reduction_rate) + pred_lens = h_len * self.reduction_rate + + recon = pred + recon_lens = pred_lens + sample = recon, recon_lens, target, target_len + + if pred.shape[1] > target.shape[1]: + pred = pred.narrow(dim=1, start=0, length=target.shape[1]).contiguous() + elif pred.shape[1] < target.shape[1]: + target = target.narrow(dim=1, start=0, length=pred.shape[1]).contiguous() + + loss_mask = create_padding_mask(pred_lens, pred.shape[1]) + loss_mask = ~loss_mask + + pred = pred[loss_mask] + target = target[loss_mask] + mean = True + if self.loss_type == 'l1': + loss = (target - pred).abs() + if mean: + loss = loss.mean() + elif self.loss_type == 'l2': + if mean: + loss = torch.nn.functional.mse_loss(target, pred) + else: + loss = torch.nn.functional.mse_loss(target, pred, reduction='none') + else: + raise NotImplementedError("unknown loss type '{loss_type}'") + return {'loss': loss, 'sample': sample} + + def get_global_cond(self, input, input_len): + cond = self.global_cond_extractor(input, length=input_len) + pad_mask = create_padding_mask(input_len, input.shape[1]) + cond = cond.masked_fill(pad_mask.unsqueeze(2), 0.0) + cond = torch.sum(cond, axis=1) + cond = cond / input_len.unsqueeze(1) + return cond + + +class VQCTCReconFinetuneModel(ASRModel): + """Todo: should modify""" + + def __init__(self, cfg: DictConfig, trainer: Trainer = None): + # Get global rank and total number of GPU workers for IterableDataset partitioning, if applicable + # Global_rank and local_rank is set by LightningModule in Lightning 1.2.0 + self.global_rank = 0 + self.world_size = 1 + self.local_rank = 0 + if trainer is not None: + self.global_rank = (trainer.node_rank * trainer.num_gpus) + trainer.local_rank + self.world_size = trainer.num_nodes * trainer.num_gpus + self.local_rank = trainer.local_rank + + self.label_type = cfg.label_type + assert self.label_type in ['char', 'phone', 'bpe'] + if self.label_type == 'bpe': + self.use_bpe = True + assert cfg.tokenizer is not None + from nemo.collections.asr.parts.mixins import ASRBPEMixin + self.bpe = ASRBPEMixin() + self.bpe._setup_tokenizer(cfg.tokenizer, register_artifact=False) + self.tokenizer = self.bpe.tokenizer + + # Initialize a dummy vocabulary + vocabulary = self.tokenizer.tokenizer.get_vocab() + + # Set the new vocabulary + assert len(cfg.decoder.vocabulary) == 0 + with open_dict(cfg): + cfg.decoder.vocabulary = ListConfig(list(vocabulary.values())) + else: + self.use_bpe = False + assert cfg.tokenizer is None + self.add_end_space = cfg.add_end_space + self.lang = cfg.lang + + super().__init__(cfg=cfg, trainer=trainer) + + if self._cfg.encoder_type == 'spec2vec': + from nemo.collections.asr.models.spec2vec.spec2vec_model import Spec2VecEncoder + self.encoder = Spec2VecEncoder(self._cfg.encoder) + encoder_param_prefix = 'spec2vec_encoder.' + elif self._cfg.encoder_type == 'st': + from nemo.collections.asr.models.st2vec.st2vec_model import ST2VecEncoder + self.encoder = ST2VecEncoder(self._cfg.encoder) + encoder_param_prefix = 'st2vec_encoder.' + elif self._cfg.encoder_type == 'feat_st': + from nemo.collections.asr.models.st2vec.st2vec_model import FeatST2VecEncoder + self.encoder = FeatST2VecEncoder(self._cfg.encoder) + encoder_param_prefix = 'st2vec_encoder.' + else: + assert self._cfg.encoder_type == 'wav2vec' + from nemo.collections.asr.modules.wav2vec_encoder import Wav2VecEncoderModel + self.encoder = Wav2VecEncoderModel(self._cfg.encoder) + encoder_param_prefix = None + if cfg.pretrain_chkpt_path is not None: + self.init_encoder_from_pretrain_model(self.encoder, encoder_param_prefix, cfg.pretrain_chkpt_path) + if self._cfg.encoder_type == 'st': + self.encoder.remove_pretraining_modules(use_teacher_encoder=self._cfg.use_teacher_encoder) + else: + self.encoder.remove_pretraining_modules() + + quantizer_input_dim = self.encoder.feature_encoder.output_dim + self.quantizer = GumbelVectorQuantizer( + dim=quantizer_input_dim, + num_vars=cfg.quantizer.latent_vars, + temp=cfg.quantizer.latent_temp, + groups=cfg.quantizer.latent_groups, + combine_groups=False, + vq_dim=cfg.quantizer.latent_dim, + time_first=True, + ) + + self.decoder = self.from_config_dict(self._cfg.decoder) + + self.freeze_finetune_updates = self._cfg.freeze_finetune_updates + + self.ctc_loss = CTCLoss( + blank_id=self.decoder.blank_idx, + zero_infinity=True, + reduction=self._cfg.get("ctc_reduction", "mean_batch"), + ) + + self.reconstructor = Reconstrutor(self._cfg.reconstructor) + + self.quant_ppl_loss_weight = cfg.quant_ppl_loss_weight + self.reconstruction_loss_weight = cfg.reconstruction_loss_weight + + # Setup metric objects + if self.use_bpe: + self._wer = WERBPE( + tokenizer=self.tokenizer, + blank_id=self.decoder.blank_idx, + batch_dim_index=0, + use_cer=self._cfg.get('use_cer', False), + ctc_decode=True, + dist_sync_on_step=True, + log_prediction=self._cfg.get("log_prediction", False), + lang=self.lang, + ) + else: + if self.label_type == 'phone': + WER_class = WER_phone + else: + WER_class = WER + self._wer = WER_class( + vocabulary=self.decoder.vocabulary, + blank_id=self.decoder.blank_idx, + batch_dim_index=0, + use_cer=self._cfg.get('use_cer', False), + ctc_decode=True, + dist_sync_on_step=True, + log_prediction=self._cfg.get("log_prediction", False), + strip_end_space=self.add_end_space, + ) + + @torch.no_grad() + def transcribe( + self, paths2audio_files: List[str], batch_size: int = 4, logprobs=False, return_hypotheses: bool = False + ) -> List[str]: + """ + Uses greedy decoding to transcribe audio files. Use this method for debugging and prototyping. + + Args: + paths2audio_files: (a list) of paths to audio files. \ + Recommended length per file is between 5 and 25 seconds. \ + But it is possible to pass a few hours long file if enough GPU memory is available. + batch_size: (int) batch size to use during inference. + Bigger will result in better throughput performance but would use more memory. + logprobs: (bool) pass True to get log probabilities instead of transcripts. + return_hypotheses: (bool) Either return hypotheses or text + With hypotheses can do some postprocessing like getting timestamp or rescoring + + Returns: + A list of transcriptions (or raw log probabilities if logprobs is True) in the same order as paths2audio_files + """ + if paths2audio_files is None or len(paths2audio_files) == 0: + return {} + + if return_hypotheses and logprobs: + raise ValueError( + "Either `return_hypotheses` or `logprobs` can be True at any given time." + "Returned hypotheses will contain the logprobs." + ) + + # We will store transcriptions here + hypotheses = [] + # Model's mode and device + mode = self.training + device = next(self.parameters()).device + # dither_value = self.preprocessor.featurizer.dither + # pad_to_value = self.preprocessor.featurizer.pad_to + dither_value = self.encoder.wav2spec.featurizer.dither + pad_to_value = self.encoder.wav2spec.featurizer.pad_to + + try: + # self.preprocessor.featurizer.dither = 0.0 + # self.preprocessor.featurizer.pad_to = 0 + self.encoder.wav2spec.featurizer.dither = 0.0 + self.encoder.wav2spec.featurizer.pad_to = 0 + # Switch model to evaluation mode + self.eval() + # Freeze the encoder and decoder modules + # Work in tmp directory - will store manifest file there + with tempfile.TemporaryDirectory() as tmpdir: + with open(os.path.join(tmpdir, 'manifest.json'), 'w') as fp: + for audio_file in paths2audio_files: + entry = {'audio_filepath': audio_file, 'duration': 100000, 'text': 'nothing'} + fp.write(json.dumps(entry) + '\n') + + config = {'paths2audio_files': paths2audio_files, 'batch_size': batch_size, 'temp_dir': tmpdir} + + temporary_datalayer = self._setup_transcribe_dataloader(config) + for test_batch in tqdm(temporary_datalayer, desc="Transcribing"): + # logits, logits_len, greedy_predictions = self( + # input_signal=test_batch[0].to(device), input_signal_length=test_batch[1].to(device) + # ) + + log_probs, logits_len, greedy_predictions, logits, prob_ppl_loss, cur_temp, prob_ppl = \ + self(input_signal=test_batch[0].to(device), input_signal_length=test_batch[1].to(device), + global_step=self.trainer.global_step) + + if logprobs: + # dump log probs per file + for idx in range(logits.shape[0]): + hypotheses.append(logits[idx][: logits_len[idx]]) + else: + current_hypotheses = self._wer.ctc_decoder_predictions_tensor( + greedy_predictions, predictions_len=logits_len, return_hypotheses=return_hypotheses, + ) + + if return_hypotheses: + # dump log probs per file + for idx in range(logits.shape[0]): + current_hypotheses[idx].y_sequence = logits[idx][: logits_len[idx]] + + hypotheses += current_hypotheses + + del greedy_predictions + del logits + del test_batch + finally: + # set mode back to its original value + self.train(mode=mode) + # self.preprocessor.featurizer.dither = dither_value + # self.preprocessor.featurizer.pad_to = pad_to_value + self.encoder.wav2spec.featurizer.dither = dither_value + self.encoder.wav2spec.featurizer.pad_to = pad_to_value + return hypotheses + + def change_vocabulary(self, new_vocabulary: List[str]): + """ + Changes vocabulary used during CTC decoding process. Use this method when fine-tuning on from pre-trained model. + This method changes only decoder and leaves encoder and pre-processing modules unchanged. For example, you would + use it if you want to use pretrained encoder when fine-tuning on a data in another language, or when you'd need + model to learn capitalization, punctuation and/or special characters. + + If new_vocabulary == self.decoder.vocabulary then nothing will be changed. + + Args: + + new_vocabulary: list with new vocabulary. Must contain at least 2 elements. Typically, \ + this is target alphabet. + + Returns: None + + """ + assert not self.use_bpe + if self.decoder.vocabulary == new_vocabulary: + logging.warning(f"Old {self.decoder.vocabulary} and new {new_vocabulary} match. Not changing anything.") + else: + if new_vocabulary is None or len(new_vocabulary) == 0: + raise ValueError(f'New vocabulary must be non-empty list of chars. But I got: {new_vocabulary}') + decoder_config = self.decoder.to_config_dict() + new_decoder_config = copy.deepcopy(decoder_config) + new_decoder_config['vocabulary'] = new_vocabulary + new_decoder_config['num_classes'] = len(new_vocabulary) + + del self.decoder + self.decoder = self.from_config_dict(new_decoder_config) + del self.ctc_loss + self.ctc_loss = CTCLoss( + num_classes=self.decoder.num_classes_with_blank - 1, + zero_infinity=True, + reduction=self._cfg.get("ctc_reduction", "mean_batch"), + ) + if self.label_type == 'phone': + WER_class = WER_phone + else: + WER_class = WER + self._wer = WER_class( + vocabulary=self.decoder.vocabulary, + batch_dim_index=0, + use_cer=self._cfg.get('use_cer', False), + ctc_decode=True, + dist_sync_on_step=True, + log_prediction=self._cfg.get("log_prediction", False), + ) + + # Update config + OmegaConf.set_struct(self._cfg.decoder, False) + self._cfg.decoder = new_decoder_config + OmegaConf.set_struct(self._cfg.decoder, True) + + logging.info(f"Changed decoder to output to {self.decoder.vocabulary} vocabulary.") + + def _setup_dataloader_from_config(self, config: Optional[Dict], noise_perturb_config): + if noise_perturb_config is not None: + noise_perturb = RandomNoisePerturbation(**noise_perturb_config) + augmentor = AudioAugmentor(perturbations=[(1.0, noise_perturb)]) + else: + augmentor = None + + shuffle = config['shuffle'] + + if 'manifest_filepath' in config and config['manifest_filepath'] is None: + logging.warning(f"Could not load dataset as `manifest_filepath` was None. Provided config : {config}") + return None + + if self.add_end_space: + config['parser_add_end_space'] = self.add_end_space + + if self.use_bpe: + dataset = audio_to_text_dataset.get_bpe_dataset(config=config, tokenizer=self.tokenizer, + augmentor=augmentor) + else: + if self.label_type == 'char': + dataset = audio_to_text_dataset.get_char_dataset(config=config, augmentor=augmentor) + elif self.label_type == 'phone': + dataset = audio_to_text_dataset.get_phone_dataset(config=config, augmentor=augmentor) + + return torch.utils.data.DataLoader( + dataset=dataset, + batch_size=config['batch_size'], + collate_fn=dataset.collate_fn, + drop_last=config.get('drop_last', False), + shuffle=shuffle, + num_workers=config.get('num_workers', 0), + pin_memory=config.get('pin_memory', False), + ) + + def setup_training_data(self, train_data_config: Optional[Union[DictConfig, Dict]]): + """ + Sets up the training data loader via a Dict-like object. + + Args: + train_data_config: A config that contains the information regarding construction + of an ASR Training dataset. + + Supported Datasets: + - :class:`~nemo.collections.asr.data.audio_to_text.AudioToCharDataset` + - :class:`~nemo.collections.asr.data.audio_to_text.AudioToBPEDataset` + - :class:`~nemo.collections.asr.data.audio_to_text.TarredAudioToCharDataset` + - :class:`~nemo.collections.asr.data.audio_to_text.TarredAudioToBPEDataset` + - :class:`~nemo.collections.asr.data.audio_to_text_dali.AudioToCharDALIDataset` + """ + if 'shuffle' not in train_data_config: + train_data_config['shuffle'] = True + + # preserve config + self._update_dataset_config(dataset_name='train', config=train_data_config) + + self._train_dl = self._setup_dataloader_from_config(config=train_data_config, + noise_perturb_config=self._cfg['noise_perturb']) + + # Need to set this because if using an IterableDataset, the length of the dataloader is the total number + # of samples rather than the number of batches, and this messes up the tqdm progress bar. + # So we set the number of steps manually (to the correct number) to fix this. + if 'is_tarred' in train_data_config and train_data_config['is_tarred']: + # We also need to check if limit_train_batches is already set. + # If it's an int, we assume that the user has set it to something sane, i.e. <= # training batches, + # and don't change it. Otherwise, adjust batches accordingly if it's a float (including 1.0). + if isinstance(self._trainer.limit_train_batches, float): + self._trainer.limit_train_batches = int( + self._trainer.limit_train_batches + * ceil((len(self._train_dl.dataset) / self.world_size) / train_data_config['batch_size']) + ) + + def setup_validation_data(self, val_data_config: Optional[Union[DictConfig, Dict]]): + """ + Sets up the validation data loader via a Dict-like object. + + Args: + val_data_config: A config that contains the information regarding construction + of an ASR Training dataset. + + Supported Datasets: + - :class:`~nemo.collections.asr.data.audio_to_text.AudioToCharDataset` + - :class:`~nemo.collections.asr.data.audio_to_text.AudioToBPEDataset` + - :class:`~nemo.collections.asr.data.audio_to_text.TarredAudioToCharDataset` + - :class:`~nemo.collections.asr.data.audio_to_text.TarredAudioToBPEDataset` + - :class:`~nemo.collections.asr.data.audio_to_text_dali.AudioToCharDALIDataset` + """ + if 'shuffle' not in val_data_config: + val_data_config['shuffle'] = False + + # preserve config + self._update_dataset_config(dataset_name='validation', config=val_data_config) + + self._validation_dl = self._setup_dataloader_from_config(config=val_data_config, noise_perturb_config=None) + + def setup_test_data(self, test_data_config: Optional[Union[DictConfig, Dict]]): + """ + Sets up the test data loader via a Dict-like object. + + Args: + test_data_config: A config that contains the information regarding construction + of an ASR Training dataset. + + Supported Datasets: + - :class:`~nemo.collections.asr.data.audio_to_text.AudioToCharDataset` + - :class:`~nemo.collections.asr.data.audio_to_text.AudioToBPEDataset` + - :class:`~nemo.collections.asr.data.audio_to_text.TarredAudioToCharDataset` + - :class:`~nemo.collections.asr.data.audio_to_text.TarredAudioToBPEDataset` + - :class:`~nemo.collections.asr.data.audio_to_text_dali.AudioToCharDALIDataset` + """ + if 'shuffle' not in test_data_config: + test_data_config['shuffle'] = False + + # preserve config + self._update_dataset_config(dataset_name='test', config=test_data_config) + + self._test_dl = self._setup_dataloader_from_config(config=test_data_config, noise_perturb_config=None) + + def setup_optimization(self, optim_config: Optional[Union[DictConfig, Dict]] = None): + """ + Prepares an optimizer from a string name and its optional config parameters. + + Args: + optim_config: A dictionary containing the following keys: + + * "lr": mandatory key for learning rate. Will raise ValueError if not provided. + * "optimizer": string name pointing to one of the available optimizers in the registry. \ + If not provided, defaults to "adam". + * "opt_args": Optional list of strings, in the format "arg_name=arg_value". \ + The list of "arg_value" will be parsed and a dictionary of optimizer kwargs \ + will be built and supplied to instantiate the optimizer. + """ + ## modifed by me!! + # If config was not explicitly passed to us + if optim_config is None: + # See if internal config has `optim` namespace + if self._cfg is not None and hasattr(self._cfg, 'optim'): + optim_config = self._cfg.optim + + # If config is still None, or internal config has no Optim, return without instantiation + if optim_config is None: + logging.info('No optimizer config provided, therefore no optimizer was created') + return + + else: + # Preserve the configuration + if not isinstance(optim_config, DictConfig): + optim_config = OmegaConf.create(optim_config) + + # See if internal config has `optim` namespace before preservation + if self._cfg is not None and hasattr(self._cfg, 'optim'): + if self._cfg.optim is None: + self._cfg.optim = copy.deepcopy(optim_config) + else: + with open_dict(self._cfg.optim): + self._cfg.optim = copy.deepcopy(optim_config) + + # Setup optimizer and scheduler + if optim_config is not None and isinstance(optim_config, DictConfig): + optim_config = OmegaConf.to_container(optim_config, resolve=True) + + if 'sched' in optim_config and optim_config[ + 'sched'] is not None and self._trainer is not None: ## this line modified! + if not isinstance(self._trainer.accumulate_grad_batches, int): + raise ValueError("We do not currently support gradient acculumation that is not an integer.") + if self._trainer.max_steps is None: + # Store information needed to calculate max_steps + optim_config['sched']['t_max_epochs'] = self._trainer.max_epochs + optim_config['sched']['t_accumulate_grad_batches'] = self._trainer.accumulate_grad_batches + optim_config['sched']['t_limit_train_batches'] = self._trainer.limit_train_batches + if self._trainer.distributed_backend is None: + optim_config['sched']['t_num_workers'] = self._trainer.num_gpus or 1 + elif self._trainer.distributed_backend == "ddp_cpu": + optim_config['sched']['t_num_workers'] = self._trainer.num_processes * self._trainer.num_nodes + elif self._trainer.distributed_backend == "ddp": + optim_config['sched']['t_num_workers'] = self._trainer.num_gpus * self._trainer.num_nodes + else: + logging.warning( + f"The lightning trainer received accelerator: {self._trainer.distributed_backend}. We " + "recommend to use 'ddp' instead." + ) + optim_config['sched']['t_num_workers'] = self._trainer.num_gpus * self._trainer.num_nodes + else: + optim_config['sched']['max_steps'] = self._trainer.max_steps + + # Force into DictConfig from nested structure + optim_config = OmegaConf.create(optim_config) + # Get back nested dict so we its mutable + optim_config = OmegaConf.to_container(optim_config, resolve=True) + + # Extract scheduler config if inside optimizer config + if 'sched' in optim_config: + scheduler_config = optim_config.pop('sched') + else: + scheduler_config = None + + # Check if caller provided optimizer name, default to Adam otherwise + optimizer_cls = optim_config.get('_target_', None) + + if optimizer_cls is None: + # Try to get optimizer name for dynamic resolution, defaulting to Adam + optimizer_name = optim_config.get('name', 'adam') + else: + if inspect.isclass(optimizer_cls): + optimizer_name = optimizer_cls.__name__.lower() + else: + # resolve the class name (lowercase) from the class path if not provided + optimizer_name = optimizer_cls.split(".")[-1].lower() + + # We are guarenteed to have lr since it is required by the argparser + # But maybe user forgot to pass it to this function + lr = optim_config.get('lr', None) + + # Check if caller has optimizer kwargs, default to empty dictionary + if 'args' in optim_config: + optimizer_args = optim_config.pop('args') + optimizer_args = optim.parse_optimizer_args(optimizer_name, optimizer_args) + else: + optimizer_args = copy.deepcopy(optim_config) + + # Remove extra parameters from optimizer_args nest + # Assume all other parameters are to be passed into optimizer constructor + optimizer_args.pop('name', None) + optimizer_args.pop('cls', None) + optimizer_args.pop('lr', None) + + # Adaptive schedulers don't need `lr` + if lr is not None: + optimizer_args['lr'] = lr + + # Actually instantiate the optimizer + if optimizer_cls is not None: + if inspect.isclass(optimizer_cls): + optimizer = optimizer_cls(self.optim_param_groups(), **optimizer_args) + logging.info("Optimizer config = %s", str(optimizer)) + + self._optimizer = optimizer + + else: + # Attempt class path resolution + try: + optimizer_cls = OmegaConf.create({'_target_': optimizer_cls}) + if lr is not None: + optimizer_config = {'lr': lr} + else: + optimizer_config = {} + optimizer_config.update(optimizer_args) + + optimizer_instance = hydra.utils.instantiate( + optimizer_cls, self.optim_param_groups(), **optimizer_config + ) # type: DictConfig + + logging.info("Optimizer config = %s", str(optimizer_instance)) + + self._optimizer = optimizer_instance + + except Exception as e: + logging.error( + "Could not instantiate class path - {} with kwargs {}".format( + optimizer_cls, str(optimizer_config) + ) + ) + raise e + + else: + optimizer = optim.get_optimizer(optimizer_name) + optimizer = optimizer(self.optim_param_groups(), **optimizer_args) + + logging.info("Optimizer config = %s", str(optimizer)) + + self._optimizer = optimizer + + # Try to instantiate scheduler for optimizer + self._scheduler = prepare_lr_scheduler( + optimizer=self._optimizer, scheduler_config=scheduler_config, train_dataloader=self._train_dl + ) + + # Return the optimizer with/without scheduler + # This return allows multiple optimizers or schedulers to be created + return self._optimizer, self._scheduler + + def forward(self, input_signal, input_signal_length, global_step): + """ + Args: + input_signal: Tensor that represents a batch of raw audio signals, + of shape [B, T]. T here represents timesteps, with 1 second of audio represented as + `self.sample_rate` number of floating point values. + input_signal_length: Vector of length B, that contains the individual lengths of the audio + sequences. + """ + ft = False if global_step is None else self.freeze_finetune_updates <= global_step + with torch.no_grad() if not ft else contextlib.suppress(): + encoded, encoded_len = self.encoder(input_signal, input_signal_length, None, None, + mask=self.training, features_only=True) + + # [B, T, D] => [B, D, T] + encoded = encoded.transpose(1, 2) + + # Ensure that shape mismatch does not occur due to padding + # Due to padding and subsequent downsampling, it may be possible that + # max sequence length computed does not match the actual max sequence length + max_output_len = encoded_len.max() + if encoded.shape[2] != max_output_len: + encoded = encoded.narrow(dim=2, start=0, length=max_output_len).contiguous() + + # [B, D, T] => [B, T, D] + encoded = encoded.transpose(1, 2) + self.quantizer.set_num_updates(global_step) + quantized_encoded, prob_ppl_loss, cur_temp, prob_ppl = self.quantizer(encoded) + # Todo: should exclude the units that are beyond the sequence length from quantization optimization + # [B, T, D] => [B, D, T] + quantized_encoded = quantized_encoded.transpose(1, 2) + + logits, encoded_len = self.decoder(encoder_output=quantized_encoded, lens=encoded_len, log_prob=False) + log_probs = torch.nn.functional.log_softmax(logits, dim=-1) + + with torch.no_grad(): + greedy_predictions = log_probs.argmax(dim=-1, keepdim=False) + + # reconstruction from quantiszed vector! + # [B, D,T] => [B, T, D] + quantized_encoded = quantized_encoded.transpose(1, 2) + specs, specs_len = self.encoder.wav2spec(input_signal=input_signal, length=input_signal_length) + recon_dict = self.reconstructor(quantized_encoded, encoded_len, specs, specs_len) + + return log_probs, encoded_len, greedy_predictions, logits, prob_ppl_loss, cur_temp, prob_ppl, recon_dict + + def training_step(self, batch, batch_nb): + signal, signal_len, transcript, transcript_len = batch + + log_probs, encoded_len, predictions, _, prob_ppl_loss, cur_temp, prob_ppl, recon_dict = \ + self(input_signal=signal, input_signal_length=signal_len, global_step=self.trainer.global_step) + + ctc_loss = self.ctc_loss( + log_probs=log_probs, targets=transcript, input_lengths=encoded_len, target_lengths=transcript_len + ) + + loss = 0 + loss += ctc_loss + if self.quant_ppl_loss_weight != 0: + weighted_prob_ppl_loss = self.quant_ppl_loss_weight * prob_ppl_loss + loss += weighted_prob_ppl_loss + if self.reconstruction_loss_weight != 0: + reconstruction_loss = recon_dict['loss'] + weighted_reconstruction_loss = self.reconstruction_loss_weight * reconstruction_loss + loss += weighted_reconstruction_loss + + tensorboard_logs = {'train_loss': loss, 'ctc_loss': ctc_loss, 'weighted_prob_ppl_loss': weighted_prob_ppl_loss, + 'prob_ppl_loss': prob_ppl_loss, 'cur_temp': cur_temp, 'prob_ppl': prob_ppl, + 'weighted_reconstruction_loss': weighted_reconstruction_loss, + 'learning_rate': self._optimizer.param_groups[0]['lr']} + + return {'loss': loss, 'log': tensorboard_logs} + + def validation_step(self, batch, batch_idx, dataloader_idx=0, decode_results=None): + signal, signal_len, transcript, transcript_len = batch + with torch.no_grad(): + log_probs, encoded_len, predictions, logits, prob_ppl_loss, cur_temp, prob_ppl, recon_dict = \ + self(input_signal=signal, input_signal_length=signal_len, global_step=self.trainer.global_step) + ctc_loss = self.ctc_loss( + log_probs=log_probs, targets=transcript, input_lengths=encoded_len, target_lengths=transcript_len + ) + loss = 0 + loss += ctc_loss + if self.quant_ppl_loss_weight != 0: + weighted_prob_ppl_loss = self.quant_ppl_loss_weight * prob_ppl_loss + # Todo: check, this is mean value of batch + loss += weighted_prob_ppl_loss + if self.reconstruction_loss_weight != 0: + reconstruction_loss = recon_dict['loss'] + weighted_reconstruction_loss = self.reconstruction_loss_weight * reconstruction_loss + loss += weighted_reconstruction_loss + + self._wer.update( + predictions=predictions, targets=transcript, target_lengths=transcript_len, predictions_lengths=encoded_len, + log_prediction=batch_idx < 3, decode_results=decode_results) + wer, wer_num, wer_denom = self._wer.compute() + return { + 'val_loss': loss, + 'val_wer_num': wer_num, + 'val_wer_denom': wer_denom, + 'val_wer': wer, + 'val_logprob': log_probs.cpu().numpy(), + 'val_logprob_len': encoded_len.cpu().numpy(), + 'val_logits': logits.cpu().numpy(), + } + + def test_step(self, batch, batch_idx, dataloader_idx=0): + decode_results = {} + logs = self.validation_step(batch, batch_idx, dataloader_idx=dataloader_idx, decode_results=decode_results) + test_logs = { + 'test_loss': logs['val_loss'], + 'test_wer_num': logs['val_wer_num'], + 'test_wer_denom': logs['val_wer_denom'], + 'test_wer': logs['val_wer'], + 'test_references': decode_results['references'], + 'test_hypotheses': decode_results['hypotheses'], + 'test_logprob': logs['val_logprob'], + 'test_logprob_len': logs['val_logprob_len'], + 'test_logits': logs['val_logits'], + } + return test_logs + + def test_dataloader(self): + if self._test_dl is not None: + return self._test_dl + + def _setup_transcribe_dataloader(self, config: Dict) -> 'torch.utils.data.DataLoader': + """ + Setup function for a temporary data loader which wraps the provided audio file. + + Args: + config: A python dictionary which contains the following keys: + paths2audio_files: (a list) of paths to audio files. The files should be relatively short fragments. \ + Recommended length per file is between 5 and 25 seconds. + batch_size: (int) batch size to use during inference. \ + Bigger will result in better throughput performance but would use more memory. + temp_dir: (str) A temporary directory where the audio manifest is temporarily + stored. + + Returns: + A pytorch DataLoader for the given audio file(s). + """ + assert not self.use_bpe + dl_config = { + 'manifest_filepath': os.path.join(config['temp_dir'], 'manifest.json'), + # 'sample_rate': self.preprocessor._sample_rate, + 'sample_rate': self.encoder.wav2spec._sample_rate, + 'labels': self.decoder.vocabulary, + 'batch_size': min(config['batch_size'], len(config['paths2audio_files'])), + 'trim_silence': True, + 'shuffle': False, + } + + temporary_datalayer = self._setup_dataloader_from_config(config=DictConfig(dl_config), + noise_perturb_config=None) + return temporary_datalayer + + @classmethod + def init_encoder_from_pretrain_model( + cls, + encoder, + encoder_param_prefix, + checkpoint_path, + *, + map_location=None, + strict: bool = True): + try: + cls._set_model_restore_state(is_being_restored=True) + + from pytorch_lightning.utilities.cloud_io import load as pl_load + if map_location is not None: + checkpoint = pl_load(checkpoint_path, map_location=map_location) + else: + checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage) + + # for past checkpoint need to add the new key + assert cls.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint + pretrain_cfg = checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY] + + # give model a chance to load something + # model.on_load_checkpoint(checkpoint) + + # load the state_dict on the model automatically + if encoder_param_prefix is not None: + encoder_state = {k[len(encoder_param_prefix):]: v for k, v in checkpoint['state_dict'].items() if + k.startswith(encoder_param_prefix)} + else: + encoder_state = checkpoint['state_dict'] + encoder.load_state_dict(encoder_state, strict=strict) + finally: + cls._set_model_restore_state(is_being_restored=False) + + def multi_test_epoch_end(self, outputs, dataloader_idx: int = 0): + val_loss_mean = torch.stack([x['test_loss'] for x in outputs]).mean() + wer_num = torch.stack([x['test_wer_num'] for x in outputs]).sum() + wer_denom = torch.stack([x['test_wer_denom'] for x in outputs]).sum() + tensorboard_logs = {'test_loss': val_loss_mean, 'test_wer': wer_num / wer_denom} + references = itertools.chain.from_iterable([x['test_references'] for x in outputs]) + hypotheses = itertools.chain.from_iterable([x['test_hypotheses'] for x in outputs]) + test_logprob = [x['test_logprob'] for x in outputs] + test_logprob_len = [x['test_logprob_len'] for x in outputs] + test_logits = [x['test_logits'] for x in outputs] + return {'test_loss': val_loss_mean, 'log': tensorboard_logs, 'decode_results': (references, hypotheses), + 'test_logprob': test_logprob, 'test_logprob_len': test_logprob_len, 'test_logits': test_logits} diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/st2vec/st2vec_config.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/st2vec/st2vec_config.py new file mode 100644 index 0000000000000000000000000000000000000000..7d8d936d5b12864342ff229ee59967cfd367fc0a --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/st2vec/st2vec_config.py @@ -0,0 +1,126 @@ +from typing import Optional, Any + +from dataclasses import field, dataclass +from omegaconf import MISSING + +from nemo.collections.asr.models.spec2vec.spec2vec_config import FeatureEncoderConfig, ProjectorConfig, \ + NoisePerturbConfig +from nemo.collections.asr.models.wav2vec.wav2vec_config import LossConfig, Wav2VecTransformerConfig, \ + Wav2VecMaskingConfig, QuantizerConfig +from nemo.collections.asr.modules.audio_preprocessing import AudioToMelSpectrogramPreprocessorConfig +from nemo.core.config.modelPT import ModelConfig + + +@dataclass +class ShiftPerturbConfig: + dist: str = 'uniform' + shift_prob: float = MISSING + max_ratio: float = 0.5 + unit: int = MISSING + max: Optional[int] = None + min: Optional[int] = None + mean: Optional[float] = None + std: Optional[float] = None + truncate: bool = True + + +@dataclass +class PitchEstimationTask: + pitch_estimator: Optional[ProjectorConfig] = None + sample_rate: int = 16000 + pitch_min: int = 80 + pitch_max: int = 400 + reduction_rate: int = MISSING + loss_type: str = 'l2' + + +@dataclass +class ReconstructionTask: + quantizer: QuantizerConfig = MISSING + reconstructor: ProjectorConfig = MISSING + global_cond_extractor: Optional[ProjectorConfig] = None + reduction_rate: int = MISSING + use_teacher_feat_prob: float = 0.0 + loss_type: str = 'l2' + + +@dataclass +class ST2VecEncoderConfig: + preprocessor: AudioToMelSpectrogramPreprocessorConfig = MISSING + + feature_encoder: FeatureEncoderConfig = FeatureEncoderConfig() + pretrained_encoder_path: Optional[str] = None + freeze_feature_encoder: bool = False + freeze_student: bool = False + noise_mix_ratio: Optional[float] = None + masking: Optional[Wav2VecMaskingConfig] = None + shifting: Optional[ShiftPerturbConfig] = None + target_shifting: Optional[ShiftPerturbConfig] = None + target_masking: Optional[Wav2VecMaskingConfig] = None + target_compute_perturb: bool = False + + target_momentum: float = 0.99 + target_momentum_final: Optional[float] = None + target_momentum_steps: Optional[int] = None + target_momentum_type: Optional[str] = None + projector: Optional[ProjectorConfig] = None + predictor: Optional[ProjectorConfig] = None + + quantizer: Optional[QuantizerConfig] = None + + n_negatives: int = field( + default=100, metadata={'help': 'Number of negatives to sample from the same audio sample'} + ) + cross_sample_negatives: int = field( + default=0, metadata={'help': 'Number of negatives to sample from any sample in the batch'} + ) + codebook_negatives: int = field(default=0, metadata={'help': 'Number of negative examples in codebook'}) + negatives_from_everywhere: bool = field( + default=False, metadata={'help': 'Sample negatives from everywhere, not just masked states'} + ) + negatives_from_noisy_features: bool = False + + pitch_estimation_task: Optional[PitchEstimationTask] = None + pitch_loss_weight: float = 0.0 + + reconstruction_task: Optional[ReconstructionTask] = None + reconstruction_loss_weight: float = 0.0 + reconstruction_quant_ppl_loss_weight: float = 0.0 + +@dataclass +class FeatST2VecEncoderConfig: + preprocessor: AudioToMelSpectrogramPreprocessorConfig = MISSING + + feature_encoder: FeatureEncoderConfig = FeatureEncoderConfig() + context_net: Wav2VecTransformerConfig = MISSING + masking: Optional[Wav2VecMaskingConfig] = None + target_masking: Optional[Wav2VecMaskingConfig] = None + + target_momentum: float = 0.99 + predictor: Optional[ProjectorConfig] = None + + n_negatives: int = field( + default=100, metadata={'help': 'Number of negatives to sample from the same audio sample'} + ) + cross_sample_negatives: int = field( + default=0, metadata={'help': 'Number of negatives to sample from any sample in the batch'} + ) + codebook_negatives: int = field(default=0, metadata={'help': 'Number of negative examples in codebook'}) + negatives_from_everywhere: bool = field( + default=False, metadata={'help': 'Sample negatives from everywhere, not just masked states'} + ) + negatives_from_noisy_features: bool = False + + +@dataclass +class ST2VecPretrainModelConfig(ModelConfig): + encoder_type: str = 'st' + st2vec_encoder: Any = MISSING + + noise_perturb: Optional[NoisePerturbConfig] = None + + loss_type: str = 'wav2vec' + logit_temp: float = field(default=0.1, metadata={'help': 'Temperature to divide logits by'}) + loss: LossConfig = LossConfig() + + expected_gpu_num: int = 1 diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/st2vec/st2vec_model.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/st2vec/st2vec_model.py new file mode 100644 index 0000000000000000000000000000000000000000..e264d6e12edf2af7bcb9e0cb4515201c6fd4f36e --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/st2vec/st2vec_model.py @@ -0,0 +1,1030 @@ +import contextlib +import math + +import numpy as np +import torch +from omegaconf import DictConfig +from torch import nn + +from nemo.collections.asr.models.st2vec.st2vec_config import ShiftPerturbConfig +# from nemo.collections.asr.modules import yin +from nemo.collections.asr.modules.wav2vec_modules import compute_mask_indices, GumbelVectorQuantizer +from nemo.collections.asr.parts.spec2vec import Projector +from nemo.collections.asr.parts.spectr_augment import GAUSSIAN_MASK +from nemo.collections.asr.parts.wav2vec import TransformerEncoder +from nemo.core.classes.common import Serialization + + +def buffered_arange(max): + if not hasattr(buffered_arange, "buf"): + buffered_arange.buf = torch.LongTensor() + if max > buffered_arange.buf.numel(): + buffered_arange.buf.resize_(max) + torch.arange(max, out=buffered_arange.buf) + return buffered_arange.buf[:max] + + +class ST2VecEncoder(nn.Module): + def __init__(self, cfg: DictConfig): + super().__init__() + + self.wav2spec = Serialization.from_config_dict(cfg.preprocessor) + + self.freeze_student = cfg.freeze_student + self.feature_encoder = Serialization.from_config_dict(cfg.feature_encoder) + if cfg.pretrained_encoder_path: + print('load feature_encoder from:', cfg.pretrained_encoder_path) + self.feature_encoder.load_state_dict( + get_state_dict('st2vec_encoder.feature_encoder.', cfg.pretrained_encoder_path)) + if self.freeze_student: + stop_grad(self.feature_encoder) + + self.mask_cfg = cfg.masking + self.target_mask_cfg = cfg.target_masking + + self.n_negatives = cfg.n_negatives + self.cross_sample_negatives = cfg.cross_sample_negatives + self.codebook_negatives = cfg.codebook_negatives + self.negatives_from_everywhere = cfg.negatives_from_everywhere + self.negatives_from_noisy_features = cfg.negatives_from_noisy_features + + if self.mask_cfg.mask_emb_type == 'zero': + self.mask_emb = 0.0 + else: + assert self.mask_cfg.mask_emb_type == 'gaussian' + num_features = cfg.preprocessor.features + self.register_buffer("mask_emb", torch.tensor(GAUSSIAN_MASK[:num_features])) + + if cfg.shifting is not None: + self.random_shift = RandomShift(cfg.shifting) + else: + self.random_shift = None + + if cfg.target_shifting is not None: + assert not cfg.target_shifting.truncate + assert cfg.target_shifting.dist == 'uniform' and cfg.target_shifting.min >= 0 + self.target_shifting = RandomShift(cfg.target_shifting) + else: + self.target_shifting = None + + cfg.projector.input_dim = self.feature_encoder.output_dim + self.projector = Projector(cfg.projector) + if cfg.pretrained_encoder_path: + print('load projector from:', cfg.pretrained_encoder_path) + self.projector.load_state_dict(get_state_dict('st2vec_encoder.projector.', cfg.pretrained_encoder_path)) + if self.freeze_student: + stop_grad(self.projector) + + self.target_compute_perturb = cfg.target_compute_perturb + + self.target_update_step = 0 + if cfg.target_momentum > 0: + self.target_feature_encoder = Serialization.from_config_dict(cfg.feature_encoder) + self.target_feature_encoder.load_state_dict(self.feature_encoder.state_dict()) + for p in self.target_feature_encoder.parameters(): + p.requires_grad = False + + self.target_projector = Projector(cfg.projector) + self.target_projector.load_state_dict(self.projector.state_dict()) + for p in self.target_projector.parameters(): + p.requires_grad = False + + if cfg.target_momentum_final is not None: + assert cfg.target_momentum_steps is not None + self.momentum_schedule = momentum_scheduler(cfg.target_momentum, cfg.target_momentum_final, + cfg.target_momentum_steps, type=cfg.target_momentum_type) + else: + self.momentum_schedule = lambda _: cfg.target_momentum + else: + self.target_feature_encoder = None + self.target_projector = None + self.momentum_schedule = None + + if cfg.predictor is not None: + cfg.predictor.input_dim = self.projector.output_dim + self.predictor = Projector(cfg.predictor) + if cfg.pretrained_encoder_path: + print('load predictor from:', cfg.pretrained_encoder_path) + self.predictor.load_state_dict(get_state_dict('st2vec_encoder.predictor.', cfg.pretrained_encoder_path)) + final_dim = self.predictor.output_dim + if self.freeze_student: + stop_grad(self.predictor) + else: + self.predictor = None + final_dim = self.projector.output_dim + + if cfg.quantizer is not None: + vq_dim = cfg.quantizer.latent_dim if cfg.quantizer.latent_dim > 0 else final_dim + self.quantizer = GumbelVectorQuantizer( + dim=self.projector.output_dim, + num_vars=cfg.quantizer.latent_vars, + temp=cfg.quantizer.latent_temp, + groups=cfg.quantizer.latent_groups, + combine_groups=False, + vq_dim=vq_dim, + time_first=True, + ) + self.project_q = nn.Linear(vq_dim, final_dim) # Todo: can utilize later + else: + self.quantizer = None + self.project_q = None + + self.pitch_pred_cfg = cfg.pitch_estimation_task + if self.pitch_pred_cfg: + self.pitch_pred_cfg.pitch_estimator.input_dim = self.feature_encoder.output_dim + self.pitch_estimator = Projector(self.pitch_pred_cfg.pitch_estimator) + + self.reconstruction_cfg = cfg.reconstruction_task + if self.reconstruction_cfg: + self.reconstruction_cfg.reconstructor.input_dim = self.feature_encoder.output_dim + if self.reconstruction_cfg.global_cond_extractor: + self.reconstruction_cfg.global_cond_extractor.input_dim = num_features + self.reconstructor = Reconstrutor(self.reconstruction_cfg) + + def forward(self, wavs, wav_lens, p_wavs, p_wav_lens, *, mask=True, features_only=False, global_step=None) -> tuple: + # specs: [B, C, T] + unmasked_specs, unmasked_specs_len = self.wav2spec(input_signal=wavs, length=wav_lens) + + if self.reconstruction_cfg: + recon_target = unmasked_specs.clone() + recon_target_len = unmasked_specs_len.clone() + + if p_wavs is None: + if features_only: + specs = unmasked_specs + specs_len = unmasked_specs_len + else: + specs = unmasked_specs.clone() + specs_len = unmasked_specs_len.clone() + else: + specs, specs_len = self.wav2spec(input_signal=p_wavs, length=p_wav_lens) + + specs_mask = create_padding_mask(unmasked_specs_len, unmasked_specs.shape[2]) if mask else None + + if not features_only: + if self.training and self.target_shifting is not None: + unmasked_specs = unmasked_specs.transpose(1, 2) + unmasked_specs, unmasked_specs_len, target_shift_num, _, target_r_shift_num = ( + self.target_shifting.shift(unmasked_specs, unmasked_specs_len, self.mask_emb) + ) + unmasked_specs = unmasked_specs.transpose(1, 2) + else: + target_shift_num = 0 + target_r_shift_num = 0 + + if self.training and self.target_mask_cfg is not None: + unmasked_specs = unmasked_specs.transpose(1, 2) + unmasked_specs, _, _ = apply_mask(self.target_mask_cfg, unmasked_specs, specs_mask, self.mask_emb) + unmasked_specs = unmasked_specs.transpose(1, 2) + + if self.momentum_schedule is not None: + assert global_step is not None + target_momentum = self.momentum_schedule(global_step) + if global_step > self.target_update_step: + ema_update(self.target_feature_encoder, self.feature_encoder, target_momentum) + ema_update(self.target_projector, self.projector, target_momentum) + self.target_update_step = global_step + target_feature_encoder = self.target_feature_encoder + target_projector = self.target_projector + else: + target_feature_encoder = self.feature_encoder + target_projector = self.projector + with torch.no_grad(): + with as_eval(target_feature_encoder, + target_projector) if not self.target_compute_perturb else contextlib.suppress(): + unmasked_features, unmasked_feature_lens, _ = target_feature_encoder(unmasked_specs, + unmasked_specs_len) + # [B, D, T] => [B, T, D] + unmasked_features = unmasked_features.transpose(1, 2) + + teacher_encoder_features = unmasked_features + teacher_encoder_features_len = unmasked_feature_lens + + unmasked_features = target_projector(unmasked_features, length=unmasked_feature_lens) + + if target_shift_num > 0: + unmasked_features = unmasked_features[:, target_shift_num:] + unmasked_feature_lens = unmasked_feature_lens - target_shift_num + if self.reconstruction_cfg and self.reconstruction_cfg.use_teacher_feat_prob: + teacher_encoder_features = teacher_encoder_features[:, target_shift_num:] + teacher_encoder_features_len = teacher_encoder_features_len - target_shift_num + + else: + assert target_shift_num == 0 + + if target_r_shift_num > 0: + unmasked_features = unmasked_features[:, :-target_r_shift_num] + unmasked_feature_lens = unmasked_feature_lens - target_r_shift_num + if self.reconstruction_cfg and self.reconstruction_cfg.use_teacher_feat_prob: + teacher_encoder_features = teacher_encoder_features[:, :-target_r_shift_num] + teacher_encoder_features_len = teacher_encoder_features_len - target_r_shift_num + else: + assert target_r_shift_num == 0 + + if self.random_shift is not None and not features_only: + specs = specs.transpose(1, 2) + specs, specs_len, shift_num, _, r_shift_num = self.random_shift.shift(specs, specs_len, self.mask_emb) + specs = specs.transpose(1, 2) + else: + shift_num = 0 + r_shift_num = 0 + + if mask: + specs = specs.transpose(1, 2) + specs, _, _ = apply_mask(self.mask_cfg, specs, specs_mask, self.mask_emb) + specs = specs.transpose(1, 2) + + features, feature_lens, _ = self.feature_encoder(specs, specs_len) + # [B, D, T] => [B, T, D] + features = features.transpose(1, 2) + if features_only: + return features, feature_lens + + assert mask + + encoder_features = features + encoder_features_len = feature_lens + + features = self.projector(features, length=feature_lens) + + if self.predictor is not None: + pred_features = self.predictor(features, length=feature_lens) + else: + pred_features = features + + if shift_num > 0: + # remove paddings introduced by shift + pred_features = pred_features[:, shift_num:] + feature_lens = feature_lens - shift_num + elif shift_num < 0: + assert self.reconstruction_cfg is None + unmasked_features = unmasked_features[:, abs(shift_num):] + unmasked_feature_lens -= abs(shift_num) + + if r_shift_num > 0: + pred_features = pred_features[:, :-r_shift_num] + feature_lens = feature_lens - r_shift_num + elif r_shift_num < 0: + assert self.reconstruction_cfg is None + unmasked_features = unmasked_features[:, :-abs(r_shift_num)] + unmasked_feature_lens -= abs(r_shift_num) + + assert pred_features.shape[1] == unmasked_features.shape[1] + assert torch.equal(feature_lens, unmasked_feature_lens) + + padding_mask = create_padding_mask(feature_lens, pred_features.shape[1]) + features_mask = ~padding_mask + + pred_features = pred_features[features_mask] + # fake batch dim to 1 + pred_features = pred_features.view(1, -1, pred_features.size(-1)) + + unmasked_features = unmasked_features[features_mask] + # fake batch dim 1 + unmasked_features = unmasked_features.view(1, -1, unmasked_features.size(-1)) + + assert not unmasked_features.requires_grad + if self.quantizer is not None: + self.quantizer.set_num_updates(global_step) + unmasked_features, prob_ppl_loss, cur_temp, prob_ppl = self.quantizer(unmasked_features) + unmasked_features = self.project_q(unmasked_features) + else: + prob_ppl_loss, cur_temp, prob_ppl = None, None, None + + sampled_negatives, _ = self.sample_negatives_flat(unmasked_features, feature_lens.tolist()) + + if self.pitch_pred_cfg: + pitch_loss = self.get_pitch_pred_loss(encoder_features, encoder_features_len, wavs, wav_lens) + else: + pitch_loss = None + + if self.reconstruction_cfg: + recon = self.reconstructor(encoder_features, encoder_features_len, teacher_encoder_features, + teacher_encoder_features_len, + recon_target, recon_target_len, global_step) + else: + recon = None + + return pred_features, unmasked_features, sampled_negatives, padding_mask, prob_ppl_loss, cur_temp, prob_ppl, pitch_loss, recon + + def get_pitch_pred_loss(self, features, feature_lens, wavs, wav_lens, mean=True): + target = yin.estimate(wavs, sample_rate=self.pitch_pred_cfg.sample_rate, + pitch_min=self.pitch_pred_cfg.pitch_min, + pitch_max=self.pitch_pred_cfg.pitch_max) + # target = wavs[:, :int(wavs.shape[1] / 160)] + target = target / self.pitch_pred_cfg.pitch_max + + pred = self.pitch_estimator(features, length=feature_lens) + + B, T, C = pred.size() + assert C == self.pitch_pred_cfg.reduction_rate + pred = pred.reshape(B, T * C) + pred_lens = feature_lens * self.pitch_pred_cfg.reduction_rate + + # import random + # if random.random() < 0.001: + # torch.set_printoptions(profile="full") + # print('\ntgtt:', target[0][:300]) + # print('pred:', pred[0][:300]) + # torch.set_printoptions(profile="default") # reset + + if pred.shape[1] > target.shape[1]: + pred = pred.narrow(dim=1, start=0, length=target.shape[1]).contiguous() + elif pred.shape[1] < target.shape[1]: + target = target.narrow(dim=1, start=0, length=pred.shape[1]).contiguous() + + loss_mask = create_padding_mask(pred_lens, pred.shape[1]) + loss_mask = ~loss_mask + + pred = pred[loss_mask] + target = target[loss_mask] + if self.pitch_pred_cfg.loss_type == 'l1': + loss = (target - pred).abs() + if mean: + loss = loss.mean() + elif self.pitch_pred_cfg.loss_type == 'l2': + if mean: + loss = torch.nn.functional.mse_loss(target, pred) + else: + loss = torch.nn.functional.mse_loss(target, pred, reduction='none') + else: + raise NotImplementedError("unknown loss type '{loss_type}'") + return loss + + def check_collapse(self, features, feature_lens, unmasked_features, proj_features): + import torch.nn.functional as F + trunc_len = min(feature_lens.min(), 80) + + feat_0_trunc = features[0, :trunc_len] + # [T, 1, C] + feat_0_trunc_src = feat_0_trunc.unsqueeze(1) + # [1, T, C] + feat_0_trunc_tgt = feat_0_trunc.unsqueeze(0) + feat_self_sim = F.cosine_similarity(feat_0_trunc_src, feat_0_trunc_tgt, dim=-1) + print('feat_self_sim: \n', feat_self_sim) + + proj_feat_0_trunc = proj_features[0, :trunc_len] + feat_proj_sim = F.cosine_similarity(feat_0_trunc, proj_feat_0_trunc, dim=-1) + print('feat_proj_sim: \n', feat_proj_sim) + + um_feat_0_trunc = unmasked_features[0, :trunc_len] + feat_um_sim = F.cosine_similarity(feat_0_trunc, um_feat_0_trunc, dim=-1) + print('feat_um_sim: \n', feat_um_sim) + + proj_um_sim = F.cosine_similarity(proj_feat_0_trunc, um_feat_0_trunc, dim=-1) + print('proj_um_sim: \n', proj_um_sim) + + feat_1_trunc = features[1, :trunc_len] + feat_cross_sim = F.cosine_similarity(feat_0_trunc, feat_1_trunc, dim=-1) + print('feat_cross_sim: \n', feat_cross_sim) + + def extract_features(self, source, audio_lengths, mask=False): + padding_mask = create_padding_mask(audio_lengths, max_len=source.shape[1]) + return self(source=source, padding_mask=padding_mask, mask=mask, features_only=True) + + def remove_pretraining_modules(self, use_teacher_encoder=False): + self.projector = None + if use_teacher_encoder: + print('use target feature encoder!', flush=True) + self.feature_encoder.load_state_dict(self.target_feature_encoder.state_dict()) + self.target_feature_encoder = None + self.target_projector = None + self.predictor = None + self.quantizer = None + self.project_q = None + + self.pitch_estimator = None + self.reconstructor = None + + def _update_quantizer_temp(self, global_step): + if self.quantize_targets: + self.quantizer.set_num_updates(global_step) + + def sample_negatives(self, y, num): + + if self.n_negatives == 0 and self.cross_sample_negatives == 0: + return y.new(0) + + bsz, tsz, fsz = y.shape + y = y.view(-1, fsz) # BTC => (BxT)C + + cross_high = tsz * bsz + high = tsz + with torch.no_grad(): + assert high > 1, f"{bsz, tsz, fsz}" + + if self.n_negatives > 0: + tszs = buffered_arange(num).unsqueeze(-1).expand(-1, self.n_negatives).flatten() + + neg_idxs = torch.randint(low=0, high=high - 1, size=(bsz, self.n_negatives * num)) + neg_idxs[neg_idxs >= tszs] += 1 + + if self.cross_sample_negatives > 0: + tszs = buffered_arange(num).unsqueeze(-1).expand(-1, self.cross_sample_negatives).flatten() + + cross_neg_idxs = torch.randint( + low=0, high=cross_high - 1, size=(bsz, self.cross_sample_negatives * num), + ) + cross_neg_idxs[cross_neg_idxs >= tszs] += 1 + + if self.n_negatives > 0: + for i in range(1, bsz): + neg_idxs[i] += i * high + else: + neg_idxs = cross_neg_idxs + + if self.cross_sample_negatives > 0 and self.n_negatives > 0: + neg_idxs = torch.cat([neg_idxs, cross_neg_idxs], dim=1) + + negs = y[neg_idxs.view(-1)] + negs = negs.view(bsz, num, self.n_negatives + self.cross_sample_negatives, fsz).permute( + 2, 0, 1, 3 + ) # to NxBxTxC + return negs, neg_idxs + + def sample_negatives_flat(self, y, nums): + + if self.n_negatives == 0 and self.cross_sample_negatives == 0: + return y.new(0) + + bsz, tsz, fsz = y.shape + assert bsz == 1 and tsz == sum(nums) # fake batch dim + y = y.view(-1, fsz) # BTC => (BxT)C + + # cross_high = tsz * bsz + + neg_idxs_l = [] + idx_start = 0 + with torch.no_grad(): + for i, num_i in enumerate(nums): + assert num_i > 1, f"{bsz, tsz, fsz}" + + assert self.n_negatives > 0 + tszs_i = buffered_arange(num_i).unsqueeze(-1).expand(-1, self.n_negatives).flatten() + + high_i = num_i + neg_idxs_i = torch.randint(low=0, high=high_i - 1, size=(self.n_negatives * num_i,)) + neg_idxs_i[neg_idxs_i >= tszs_i] += 1 + + neg_idxs_i += idx_start + idx_start += num_i + + neg_idxs_l.append(neg_idxs_i) + + assert self.cross_sample_negatives == 0 + + neg_idxs = torch.cat(neg_idxs_l) + assert neg_idxs.ndim == 1 + + negs = y[neg_idxs] + negs = negs.view(bsz, sum(nums), self.n_negatives + self.cross_sample_negatives, fsz).permute( + 2, 0, 1, 3 + ) # to NxBxTxC + return negs, neg_idxs + + +def create_padding_mask(audio_lengths, max_len): + # Broadcast to vectorize creating the padding mask + padding_mask = torch.arange(max_len, device=audio_lengths.device) + padding_mask = padding_mask.expand(len(audio_lengths), max_len) < audio_lengths.unsqueeze(1) + # Negate to false where no padding + padding_mask = ~padding_mask + return padding_mask + + +class RandomShift: + def __init__(self, cfg: ShiftPerturbConfig): + self.dist = cfg.dist + if self.dist == 'uniform': + assert isinstance(cfg.max, int) and isinstance(cfg.min, int) + self.min = cfg.min + self.max = cfg.max + else: + assert cfg.dist == 'rounded_normal' + assert isinstance(cfg.mean, float) and isinstance(cfg.std, float) + self.mean = cfg.mean + self.std = cfg.std + self.max_ratio = cfg.max_ratio + assert isinstance(cfg.unit, int) + self.unit = cfg.unit + self.shift_prob = cfg.shift_prob + self.truncate = cfg.truncate + + def shift(self, inputs, inputs_len, mask_emb): + if np.random.random() >= self.shift_prob: + return inputs, inputs_len, 0, 0, 0 + + shift_num, shift_num_units, r_shift_num, r_shift_num_units = self.get_shift_num(inputs_len.min()) + + if self.truncate and shift_num > 0 and r_shift_num > 0: + r_shift_num = 0 + r_shift_num_units = 0 + + orig_inputs_t = inputs.shape[1] + + if shift_num_units > 0: + inputs = torch.nn.functional.pad(inputs, (0, 0, shift_num_units, 0)) + inputs[:, :shift_num_units] = mask_emb + inputs_len = inputs_len + shift_num_units + elif shift_num_units < 0: + abs_shift_num_units = abs(shift_num_units) + inputs = inputs[:, abs_shift_num_units:] + inputs_len = inputs_len - abs_shift_num_units + + if r_shift_num_units > 0: + inputs = torch.nn.functional.pad(inputs, (0, 0, 0, r_shift_num_units)) + shift_padding_mask = create_shift_padding_mask(inputs_len, inputs.shape[1], r_shift_num_units) + inputs[shift_padding_mask] = mask_emb + inputs_len = inputs_len + r_shift_num_units + elif r_shift_num_units < 0: + shift_padding_mask = create_shift_padding_mask(inputs_len, inputs.shape[1], r_shift_num_units) + inputs[shift_padding_mask] = 0.0 + abs_shift_num_units = abs(r_shift_num_units) + inputs_len = inputs_len - abs_shift_num_units + inputs = inputs[:, :-abs_shift_num_units] + + inputs_t_diff = inputs.shape[1] - orig_inputs_t + if self.truncate and inputs_t_diff > 0: + truncated_r_shift_num = r_shift_num - int(inputs_t_diff / self.unit) + assert truncated_r_shift_num == -shift_num + inputs = inputs[:, :-inputs_t_diff] + inputs_len = inputs_len - inputs_t_diff + else: + truncated_r_shift_num = r_shift_num + + return inputs, inputs_len, shift_num, r_shift_num, truncated_r_shift_num + + def get_shift_num(self, total_units_num): + if self.dist == 'uniform': + shift_num = np.random.randint(self.min, self.max + 1) + r_shift_num = np.random.randint(self.min, self.max + 1) + else: + shift_num = np.random.normal(loc=self.mean, scale=self.std) + shift_num = int(round(shift_num)) + r_shift_num = np.random.normal(loc=self.mean, scale=self.std) + r_shift_num = int(round(r_shift_num)) + + max_num = int(total_units_num * self.max_ratio / self.unit) + if shift_num > max_num: + if self.truncate: + shift_num = max_num + elif shift_num < -max_num: + shift_num = -max_num + + if r_shift_num < 0: + if shift_num > 0: + r_shift_num = max(-max_num, r_shift_num) + else: + r_shift_num = max(-(max_num - abs(shift_num)), r_shift_num) + + return shift_num, shift_num * self.unit, r_shift_num, r_shift_num * self.unit + + +def create_shift_padding_mask(lengths, max_len, shift_num_units): + positions = torch.arange(max_len, device=lengths.device) + positions.expand(len(lengths), max_len) + shift_audio_lengths = lengths + shift_num_units + if shift_num_units > 0: + padding_mask = (positions >= lengths.unsqueeze(1)) & (positions < shift_audio_lengths.unsqueeze(1)) + else: + padding_mask = (positions >= shift_audio_lengths.unsqueeze(1)) & (positions < lengths.unsqueeze(1)) + return padding_mask + + +class FeatST2VecEncoder(nn.Module): + def __init__(self, cfg: DictConfig): + super().__init__() + + self.wav2spec = Serialization.from_config_dict(cfg.preprocessor) + + self.feature_encoder = Serialization.from_config_dict(cfg.feature_encoder) + self.target_momentum = cfg.target_momentum + self.target_update_step = 0 + if self.target_momentum > 0: + self.target_feature_encoder = Serialization.from_config_dict(cfg.feature_encoder) + self.target_feature_encoder.load_state_dict(self.feature_encoder.state_dict()) + for p in self.target_feature_encoder.parameters(): + p.requires_grad = False + else: + self.target_feature_encoder = None + + context_net_dim = cfg.context_net.encoder.embedding_dim + self.feature_proj = nn.Linear(self.feature_encoder.output_dim, context_net_dim) + + self.context_net = TransformerEncoder(cfg.context_net) + + self.mask_cfg = cfg.masking + self.target_mask_cfg = cfg.target_masking + + if self.mask_cfg.mask_emb_type == 'zero': + self.mask_emb = 0.0 + else: + assert self.mask_cfg.mask_emb_type == 'gaussian' + num_features = cfg.preprocessor.features + self.register_buffer("mask_emb", torch.tensor(GAUSSIAN_MASK[:num_features])) + + self.n_negatives = cfg.n_negatives + self.cross_sample_negatives = cfg.cross_sample_negatives + self.codebook_negatives = cfg.codebook_negatives + self.negatives_from_everywhere = cfg.negatives_from_everywhere + self.negatives_from_noisy_features = cfg.negatives_from_noisy_features + + cfg.predictor.input_dim = context_net_dim + self.predictor = None if cfg.predictor is None else Projector(cfg.predictor) + + def forward(self, wavs, wav_lens, *, mask=True, features_only=False, global_step=None) -> tuple: + specs, specs_len = self.wav2spec( + input_signal=wavs, length=wav_lens, + ) + + specs_mask = self._create_padding_mask(specs_len, specs.shape[1]) if mask else None + + if features_only: + unmasked_features = None + padding_mask = None + features_mask = None + else: + unmasked_specs = specs.clone() + if self.training and self.target_mask_cfg is not None: + unmasked_specs = unmasked_specs.transpose(1, 2) + unmasked_specs, _, _ = apply_mask(self.target_mask_cfg, unmasked_specs, specs_mask, self.mask_emb) + unmasked_specs = unmasked_specs.transpose(1, 2) + if self.target_momentum > 0: + assert global_step is not None + if global_step > self.target_update_step: + ema_update(self.target_feature_encoder, self.feature_encoder, self.target_momentum) + self.target_update_step = global_step + target_feature_encoder = self.target_feature_encoder + else: + target_feature_encoder = self.feature_encoder + with torch.no_grad(): + with as_eval(target_feature_encoder): + unmasked_features, unmasked_feature_lens, _ = target_feature_encoder(unmasked_specs, specs_len) + # [B, D, T] => [B, T, D] + unmasked_features = unmasked_features.transpose(1, 2) + + padding_mask = self._create_padding_mask(unmasked_feature_lens, unmasked_features.shape[1]) + features_mask = ~padding_mask + unmasked_features = unmasked_features[features_mask] + # fake batch dim 1 + unmasked_features = unmasked_features.view(1, -1, unmasked_features.size(-1)) + + if mask: + specs = specs.transpose(1, 2) + specs, _, _ = apply_mask(self.mask_cfg, specs, specs_mask, self.mask_emb) + specs = specs.transpose(1, 2) + + features, feature_lens, _ = self.feature_encoder(specs, specs_len) + # [B, D, T] => [B, T, D] + features = features.transpose(1, 2) + + if self.feature_proj is not None: + features = self.feature_proj(features) + + features = self.context_net(features, padding_mask=padding_mask) + + if features_only: + return features, feature_lens + + assert mask + + if self.predictor is not None: + features = self.predictor(features, length=feature_lens) + + features = features[features_mask] + # fake batch dim to 1 + features = features.view(1, -1, features.size(-1)) + + assert not self.negatives_from_everywhere + if self.negatives_from_noisy_features: + sampled_negatives, _ = self.sample_negatives_flat(features, feature_lens.tolist()) + else: + with torch.no_grad(): + sampled_negatives, _ = self.sample_negatives_flat(unmasked_features, feature_lens.tolist()) + + prob_ppl_loss, cur_temp, prob_ppl = None, None, None + return features, unmasked_features, sampled_negatives, padding_mask, prob_ppl_loss, cur_temp, prob_ppl + + def extract_features(self, source, audio_lengths, mask=False): + padding_mask = self._create_padding_mask(audio_lengths, max_len=source.shape[1]) + return self(source=source, padding_mask=padding_mask, mask=mask, features_only=True) + + def remove_pretraining_modules(self): + self.target_feature_encoder = None + self.predictor = None + + def _update_quantizer_temp(self, global_step): + if self.quantize_targets: + self.quantizer.set_num_updates(global_step) + + def sample_negatives(self, y, num): + + if self.n_negatives == 0 and self.cross_sample_negatives == 0: + return y.new(0) + + bsz, tsz, fsz = y.shape + y = y.view(-1, fsz) # BTC => (BxT)C + + cross_high = tsz * bsz + high = tsz + with torch.no_grad(): + assert high > 1, f"{bsz, tsz, fsz}" + + if self.n_negatives > 0: + tszs = buffered_arange(num).unsqueeze(-1).expand(-1, self.n_negatives).flatten() + + neg_idxs = torch.randint(low=0, high=high - 1, size=(bsz, self.n_negatives * num)) + neg_idxs[neg_idxs >= tszs] += 1 + + if self.cross_sample_negatives > 0: + tszs = buffered_arange(num).unsqueeze(-1).expand(-1, self.cross_sample_negatives).flatten() + + cross_neg_idxs = torch.randint( + low=0, high=cross_high - 1, size=(bsz, self.cross_sample_negatives * num), + ) + cross_neg_idxs[cross_neg_idxs >= tszs] += 1 + + if self.n_negatives > 0: + for i in range(1, bsz): + neg_idxs[i] += i * high + else: + neg_idxs = cross_neg_idxs + + if self.cross_sample_negatives > 0 and self.n_negatives > 0: + neg_idxs = torch.cat([neg_idxs, cross_neg_idxs], dim=1) + + negs = y[neg_idxs.view(-1)] + negs = negs.view(bsz, num, self.n_negatives + self.cross_sample_negatives, fsz).permute( + 2, 0, 1, 3 + ) # to NxBxTxC + return negs, neg_idxs + + def sample_negatives_flat(self, y, nums): + + if self.n_negatives == 0 and self.cross_sample_negatives == 0: + return y.new(0) + + bsz, tsz, fsz = y.shape + assert bsz == 1 and tsz == sum(nums) # fake batch dim + y = y.view(-1, fsz) # BTC => (BxT)C + + # cross_high = tsz * bsz + + neg_idxs_l = [] + idx_start = 0 + with torch.no_grad(): + for i, num_i in enumerate(nums): + assert num_i > 1, f"{bsz, tsz, fsz}" + + assert self.n_negatives > 0 + tszs_i = buffered_arange(num_i).unsqueeze(-1).expand(-1, self.n_negatives).flatten() + + high_i = num_i + neg_idxs_i = torch.randint(low=0, high=high_i - 1, size=(self.n_negatives * num_i,)) + neg_idxs_i[neg_idxs_i >= tszs_i] += 1 + + neg_idxs_i += idx_start + idx_start += num_i + + neg_idxs_l.append(neg_idxs_i) + + assert self.cross_sample_negatives == 0 + + neg_idxs = torch.cat(neg_idxs_l) + assert neg_idxs.ndim == 1 + + negs = y[neg_idxs] + negs = negs.view(bsz, sum(nums), self.n_negatives + self.cross_sample_negatives, fsz).permute( + 2, 0, 1, 3 + ) # to NxBxTxC + return negs, neg_idxs + + def _create_padding_mask(self, audio_lengths, max_len): + # Broadcast to vectorize creating the padding mask + padding_mask = torch.arange(max_len, device=audio_lengths.device) + padding_mask = padding_mask.expand(len(audio_lengths), max_len) < audio_lengths.unsqueeze(1) + # Negate to false where no padding + padding_mask = ~padding_mask + return padding_mask + + +def apply_mask(mask_cfg, x, padding_mask, mask_emb, mask_positions=None): + B, T, C = x.shape + if mask_cfg.mask_prob > 0: + mask_indices, mask_num = compute_mask_indices( + (B, T), + padding_mask, + mask_cfg.mask_prob, + mask_cfg.mask_length, + mask_cfg.mask_type, + mask_cfg.mask_other, + min_masks=2, + no_overlap=mask_cfg.no_mask_overlap, + min_space=mask_cfg.mask_min_space, + shrink_to_batch_min=mask_cfg.mask_shrink_to_batch_min, + mask_positions=mask_positions + ) + mask_indices = torch.from_numpy(mask_indices).to(x.device) + if isinstance(mask_emb, torch.Tensor): + mask_emb = mask_emb.type_as(x) + x[mask_indices] = mask_emb + assert len(mask_num) == B + else: + mask_indices = None + mask_num = None + + if mask_cfg.mask_channel_prob > 0: + # assert mask_cfg.mask_shrink_to_batch_min + mask_channel_indices, _ = compute_mask_indices( + (B, C), + None, + mask_cfg.mask_channel_prob, + mask_cfg.mask_channel_length, + mask_cfg.mask_channel_type, + mask_cfg.mask_channel_other, + no_overlap=mask_cfg.no_mask_channel_overlap, + min_space=mask_cfg.mask_channel_min_space, + shrink_to_batch_min=mask_cfg.mask_channel_shrink_to_batch_min, + ) + mask_channel_indices = torch.from_numpy(mask_channel_indices).to(x.device).unsqueeze(1).expand(-1, T, -1) + x[mask_channel_indices] = 0 + + return x, mask_indices, mask_num + + +def ema_update(ema_module, new_module, m): + with torch.no_grad(): + for param_q, param_k in zip(new_module.parameters(), ema_module.parameters()): + param_k.data.mul_(m).add_((1 - m) * param_q.detach().data) + + +@contextlib.contextmanager +def as_eval(*modules): + training_states = [] + for module_i in modules: + training_states.append(module_i.training) + module_i.eval() + + try: + yield + finally: + for module_i, training_state_i in zip(modules, training_states): + module_i.train(training_state_i) + + +def momentum_scheduler(base_value, final_value, max_steps, *, type): + if type == 'linear': + def linear_scheduler(step): + if step <= max_steps: + cur_value = base_value + (final_value - base_value) * (step / max_steps) + else: + cur_value = final_value + return cur_value + + return linear_scheduler + elif type == 'cosine': + def cosine_scheduler(step): + if step <= max_steps: + cur_value = final_value + 0.5 * (base_value - final_value) * (1 + math.cos(math.pi * step / max_steps)) + else: + cur_value = final_value + return cur_value + + return cosine_scheduler + else: + raise ValueError('unknown scheduler type: {}'.format(type)) + + +def get_state_dict( + module_param_prefix, + checkpoint_path, + *, + map_location=None, + strict: bool = True): + from pytorch_lightning.utilities.cloud_io import load as pl_load + if map_location is not None: + checkpoint = pl_load(checkpoint_path, map_location=map_location) + else: + checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage) + + # load the state_dict on the model automatically + if module_param_prefix is not None: + module_state_dict = {k[len(module_param_prefix):]: v for k, v in checkpoint['state_dict'].items() if + k.startswith(module_param_prefix)} + else: + module_state_dict = checkpoint['state_dict'] + return module_state_dict + + +class Reconstrutor(nn.Module): + def __init__(self, cfg: DictConfig): + super().__init__() + + self.use_teacher_feat_prob = cfg.use_teacher_feat_prob + + self.quantizer = GumbelVectorQuantizer( + dim=cfg.reconstructor.input_dim, + num_vars=cfg.quantizer.latent_vars, + temp=cfg.quantizer.latent_temp, + groups=cfg.quantizer.latent_groups, + combine_groups=False, + vq_dim=cfg.quantizer.latent_dim, + time_first=True, + ) + cfg.reconstructor.input_dim = cfg.quantizer.latent_dim + self.decoder = Projector(cfg.reconstructor) + self.reduction_rate = cfg.reduction_rate + self.loss_type = cfg.loss_type + + if cfg.global_cond_extractor is not None: + self.global_cond_extractor = Projector(cfg.global_cond_extractor) + else: + self.global_cond_extractor = None + + def _update_quantizer_temp(self, global_step): + self.quantizer.set_num_updates(global_step) + + def forward(self, h, h_len, t_h, t_h_len, target, target_len, global_step): + use_teacher_feat = False + if self.training and self.use_teacher_feat_prob > 0 and np.random.random() < self.use_teacher_feat_prob: + h = t_h + h_len = t_h_len + use_teacher_feat = True + + # [B, D, T] => [B, T, D] + target = target.transpose(1, 2) + + if self.global_cond_extractor is not None: + global_cond = self.get_global_cond(target, target_len) + else: + global_cond = None + + self.quantizer.set_num_updates(global_step) + h, prob_ppl_loss, cur_temp, prob_ppl = self.quantizer(h) + + if self.global_cond_extractor is not None: + h = h + global_cond.unsqueeze(1) + + pred = self.decoder(h, length=h_len) + + B, T, C = pred.size() + assert C // self.reduction_rate == target.shape[2] + pred = pred.reshape(B, T * self.reduction_rate, C // self.reduction_rate) + pred_lens = h_len * self.reduction_rate + + recon = pred + recon_lens = pred_lens + sample = recon, recon_lens, target, target_len + + import random + if random.random() < 0.0005: + torch.set_printoptions(profile="full") + print('\n use teacher feat: {}'.format(use_teacher_feat)) + print('tgtt:', target[0][180:260]) + print('pred:', pred[0][180:260]) + torch.set_printoptions(profile="default") # reset + + if pred.shape[1] > target.shape[1]: + pred = pred.narrow(dim=1, start=0, length=target.shape[1]).contiguous() + elif pred.shape[1] < target.shape[1]: + target = target.narrow(dim=1, start=0, length=pred.shape[1]).contiguous() + + loss_mask = create_padding_mask(pred_lens, pred.shape[1]) + loss_mask = ~loss_mask + + pred = pred[loss_mask] + target = target[loss_mask] + mean = True + if self.loss_type == 'l1': + loss = (target - pred).abs() + if mean: + loss = loss.mean() + elif self.loss_type == 'l2': + if mean: + loss = torch.nn.functional.mse_loss(target, pred) + else: + loss = torch.nn.functional.mse_loss(target, pred, reduction='none') + else: + raise NotImplementedError("unknown loss type '{loss_type}'") + return {'loss': loss, + 'recon': sample, + 'quant_ppl_loss': prob_ppl_loss, + 'quant_temp': cur_temp, + 'quant_ppl': prob_ppl} + + def get_quantized_feat(self, feat, feat_len): + q, q_ids = self.quantizer(feat, quant_only=True) + return q, q_ids + + def get_global_cond(self, input, input_len): + cond = self.global_cond_extractor(input, length=input_len) + pad_mask = create_padding_mask(input_len, input.shape[1]) + cond = cond.masked_fill(pad_mask.unsqueeze(2), 0.0) + cond = torch.sum(cond, axis=1) + cond = cond / input_len.unsqueeze(1) + return cond + + +def stop_grad(module): + for p in module.parameters(): + p.requires_grad = False diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/st2vec/st2vec_pretrain.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/st2vec/st2vec_pretrain.py new file mode 100644 index 0000000000000000000000000000000000000000..ea817fe92236ab5c6915fe5f3b159f7c736a8707 --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/st2vec/st2vec_pretrain.py @@ -0,0 +1,296 @@ +import logging +from math import ceil +from typing import Dict, Optional, Union + +import torch +from omegaconf import DictConfig, OmegaConf +from pytorch_lightning import Trainer + +from nemo.collections.asr.data import audio_to_text_dataset +from nemo.collections.asr.losses.similarityloss import NegativeCosineSimilarityLoss +from nemo.collections.asr.losses.wav2vecloss import Wav2VecLoss +from nemo.collections.asr.models.st2vec.st2vec_model import ST2VecEncoder, FeatST2VecEncoder +from nemo.collections.asr.parts.perturb import process_augmentations, RandomNoisePerturbation, AudioAugmentor +from nemo.core import ModelPT +from nemo.core.classes.common import PretrainedModelInfo +import pickle +import numpy as np + +def buffered_arange(max): + if not hasattr(buffered_arange, "buf"): + buffered_arange.buf = torch.LongTensor() + if max > buffered_arange.buf.numel(): + buffered_arange.buf.resize_(max) + torch.arange(max, out=buffered_arange.buf) + return buffered_arange.buf[:max] + + +class ST2VecPretrainModel(ModelPT): + def __init__(self, cfg: DictConfig, trainer: Trainer = None): + # Get global rank and total number of GPU workers for IterableDataset partitioning, if applicable + self.global_rank = 0 + self.world_size = 1 + self.local_rank = 0 + if trainer is not None: + self.global_rank = (trainer.node_rank * trainer.num_gpus) + trainer.local_rank + self.world_size = trainer.num_nodes * trainer.num_gpus + self.local_rank = trainer.local_rank + + super().__init__(cfg=cfg, trainer=trainer) + + if isinstance(cfg, dict): + cfg = OmegaConf.create(cfg) + elif not isinstance(cfg, DictConfig): + raise ValueError(f"cfg was type: {type(cfg)}. Expected either a dict or a DictConfig") + + if cfg.encoder_type == 'st': + self.st2vec_encoder = ST2VecEncoder(cfg.st2vec_encoder) + elif cfg.encoder_type == 'feat_st': + self.st2vec_encoder = FeatST2VecEncoder(cfg.st2vec_encoder) + else: + raise ValueError('unknown encoder type: {}'.format(cfg.encoder_type)) + + self.loss_type = cfg.loss_type + if self.loss_type == 'neg_cos_sim': + self.loss = NegativeCosineSimilarityLoss() + else: + assert self.loss_type == 'wav2vec' + self.loss = Wav2VecLoss( + feature_loss_weight=0.0, + prob_ppl_weight=cfg.loss.prob_ppl_weight, + logit_temp=cfg.logit_temp, + ) + + self.pitch_loss_weight = cfg.st2vec_encoder.pitch_loss_weight + self.reconstruction_loss_weight = cfg.st2vec_encoder.reconstruction_loss_weight + self.reconstruction_quant_ppl_loss_weight = cfg.st2vec_encoder.reconstruction_quant_ppl_loss_weight + + self._prev_log_step = -1 + + def training_step(self, batch, batch_idx): + loss, contrastive_loss, prob_ppl_loss, cur_temp, prob_ppl, _, pitch_loss, recon = self._step(batch) + + if self.global_step > self._prev_log_step: + self._prev_log_step = self.global_step + tensorboard = self.logger.experiment + tensorboard.add_scalar('loss', loss, self.global_step) + if prob_ppl_loss is not None: + tensorboard.add_scalar('contrastive_loss', contrastive_loss, self.global_step) + tensorboard.add_scalar('prob_ppl_loss', prob_ppl_loss, self.global_step) + tensorboard.add_scalar('temp', cur_temp, self.global_step) + tensorboard.add_scalar('prob_ppl', prob_ppl, self.global_step) + if self.pitch_loss_weight: + tensorboard.add_scalar('pitch_loss', pitch_loss, self.global_step) + if self.reconstruction_loss_weight: + tensorboard.add_scalar('contrastive_loss', contrastive_loss, self.global_step) + tensorboard.add_scalar('recon_loss', recon['loss'], self.global_step) + tensorboard.add_scalar('recon_quant_ppl_loss', recon['quant_ppl_loss'], self.global_step) + tensorboard.add_scalar('recon_quant_ppl', recon['quant_ppl'], self.global_step) + tensorboard.add_scalar('recon_quant_temp', recon['quant_temp'], self.global_step) + tensorboard.add_scalar('learning_rate', self._optimizer.param_groups[0]['lr'], self.global_step) + return {'loss': loss} + + def validation_step(self, batch, batch_idx, dataloader_idx=0): + loss, contrastive_loss, prob_ppl_loss, _, prob_ppl, accuracy, pitch_loss, recon = self._step(batch) + self.log('val_loss', loss, prog_bar=True, on_epoch=True, sync_dist=True) + if prob_ppl is not None: + self.log('val_contrastive_loss', contrastive_loss, prog_bar=False, on_step=False, on_epoch=True, sync_dist=False) + self.log('val_prob_ppl', prob_ppl, prog_bar=False, on_step=False, on_epoch=True, sync_dist=False) + if accuracy is not None: + self.log('val_accuracy', accuracy, prog_bar=True, on_step=False, on_epoch=True, sync_dist=False) + if self.pitch_loss_weight: + self.log('val_pitch_loss', pitch_loss, prog_bar=False, on_step=False, on_epoch=True, sync_dist=False) + if self.reconstruction_loss_weight: + self.log('val_contrastive_loss', contrastive_loss, prog_bar=False, on_step=False, on_epoch=True, sync_dist=False) + self.log('val_recon_loss', recon['loss'], prog_bar=False, on_step=False, on_epoch=True, sync_dist=False) + self.log('val_recon_quant_ppl_loss', recon['quant_ppl_loss'], prog_bar=False, on_step=False, on_epoch=True, sync_dist=False) + self.log('val_recon_quant_ppl', recon['quant_ppl'], prog_bar=False, on_step=False, on_epoch=True, sync_dist=False) + + def test_step(self, batch, batch_idx, dataloader_idx=0): + loss, contrastive_loss, prob_ppl_loss, _, _, accuracy, _, recon = self._step(batch) + self.log('test_loss', loss, prog_bar=True, on_epoch=True, sync_dist=True) + if accuracy is not None: + self.log('test_accuracy', accuracy, prog_bar=True, on_step=False, on_epoch=True, sync_dist=False) + if recon is not None and batch_idx < 4: + recon, recon_lens, target, target_len = recon['recon'] + torch.set_printoptions(profile="full") + print('\n recon sample: {}'.format(batch_idx)) + print('tgtt:', target[0][120:240]) + print('pred:', recon[0][120:240]) + torch.set_printoptions(profile="default") # reset + + def _step(self, batch): + if len(batch) == 4: + audio_signal, audio_lengths, p_audio_signal, p_audio_lengths = batch + else: + audio_signal, audio_lengths = batch + p_audio_signal, p_audio_lengths = None, None + + logits, targets, sampled_negatives, _, prob_ppl_loss, cur_temp, prob_ppl, pitch_loss, recon = self( + source=audio_signal, source_lens=audio_lengths, p_source=p_audio_signal, p_source_lens=p_audio_lengths + ) + if self.loss_type == 'neg_cos_sim': + loss = self.loss(predictions=logits, targets=targets) + contrastive_loss, prob_ppl_loss, accuracy = None, None, None + else: + assert self.loss_type == 'wav2vec' + loss, contrastive_loss, _, prob_ppl_loss, accuracy = self.loss( + logits=logits, + targets=targets, + negatives=sampled_negatives, + prob_ppl_loss=prob_ppl_loss, + feature_loss=None, + compute_accuracy=not self.training + ) + + if self.pitch_loss_weight: + loss = loss + self.pitch_loss_weight * pitch_loss + + if self.reconstruction_loss_weight: + loss = loss + self.reconstruction_loss_weight * recon['loss'] + self.reconstruction_quant_ppl_loss_weight * recon['quant_ppl_loss'] + + return loss, contrastive_loss, prob_ppl_loss, cur_temp, prob_ppl, accuracy, pitch_loss, recon + + @classmethod + def list_available_models(cls) -> Optional[PretrainedModelInfo]: + return None + + def forward(self, source, source_lens, p_source, p_source_lens, mask=True, features_only=False) -> tuple: + return self.st2vec_encoder(source, source_lens, p_source, p_source_lens, mask=mask, features_only=features_only, + global_step=self.global_step) + + def setup_training_data(self, train_data_config: Optional[Union[DictConfig, Dict]]): + if 'shuffle' not in train_data_config: + train_data_config['shuffle'] = True + + # preserve config + self._update_dataset_config(dataset_name='train', config=train_data_config) + + self._train_dl = self._setup_dataloader_from_config(config=train_data_config, noise_perturb_config=self._cfg['noise_perturb']) + + # Need to set this because if using an IterableDataset, the length of the dataloader is the total number + # of samples rather than the number of batches, and this messes up the tqdm progress bar. + # So we set the number of steps manually (to the correct number) to fix this. + if 'is_tarred' in train_data_config and train_data_config['is_tarred']: + # We also need to check if limit_train_batches is already set. + # If it's an int, we assume that the user has set it to something sane, i.e. <= # training batches, + # and don't change it. Otherwise, adjust batches accordingly if it's a float (including 1.0). + if isinstance(self._trainer.limit_train_batches, float): + self._trainer.limit_train_batches = int( + self._trainer.limit_train_batches + * ceil((len(self._train_dl.dataset) / self.world_size) / train_data_config['batch_size']) + ) + + def setup_validation_data(self, val_data_config: Optional[Union[DictConfig, Dict]]): + if 'shuffle' not in val_data_config: + val_data_config['shuffle'] = False + + # preserve config + self._update_dataset_config(dataset_name='validation', config=val_data_config) + + self._validation_dl = self._setup_dataloader_from_config(config=val_data_config, noise_perturb_config=None) + + def setup_test_data(self, test_data_config: Optional[Union[DictConfig, Dict]]): + if 'shuffle' not in test_data_config: + test_data_config['shuffle'] = False + + # preserve config + self._update_dataset_config(dataset_name='test', config=test_data_config) + + self._test_dl = self._setup_dataloader_from_config(config=test_data_config, noise_perturb_config=None) + + def _setup_dataloader_from_config(self, config: Optional[Dict], noise_perturb_config): + + if noise_perturb_config is not None: + noise_perturb = RandomNoisePerturbation(**noise_perturb_config) + augmentor = AudioAugmentor(perturbations=[(1.0, noise_perturb)]) + return_both = True + else: + augmentor = None + return_both = False + + shuffle = config['shuffle'] + + if 'manifest_filepath' in config and config['manifest_filepath'] is None: + logging.warning(f"Could not load dataset as `manifest_filepath` was None. Provided config : {config}") + return None + + dataset = audio_to_text_dataset.get_audio_dataset(config=config, augmentor=augmentor, return_both=return_both) + + return torch.utils.data.DataLoader( + dataset=dataset, + batch_size=config['batch_size'], + collate_fn=dataset.collate_fn, + drop_last=config.get('drop_last', False), + shuffle=shuffle, + num_workers=config.get('num_workers', 0), + pin_memory=config.get('pin_memory', False), + ) + + @torch.no_grad() + def extract_feature(self, output_dir): + # Model's mode and device + mode = self.training + device = next(self.parameters()).device + preprocessor = self.st2vec_encoder.wav2spec + pad_to_value = preprocessor.featurizer.pad_to + + extracted_feat = [] + extracted_feat_cnt = 1 + feat_file = None + if self.st2vec_encoder.reconstruction_cfg: + feat_path = output_dir / 'feat.txt' + feat_file = open(feat_path, 'w') + + try: + # preprocessor.featurizer.pad_to = 0 + # Switch model to evaluation mode + self.eval() + # Freeze the encoder and decoder modules + # Work in tmp directory - will store manifest file there + start = extracted_feat_cnt + + for test_batch in self._test_dl: + + batch = [d_i.to(device) for d_i in test_batch] + + if len(batch) == 4: + audio_signal, audio_lengths, _, _ = batch + else: + audio_signal, audio_lengths = batch + + feat, feat_len = self.st2vec_encoder(audio_signal, audio_lengths, None, None, mask=False, + features_only=True, global_step=self.global_step) + +# res = {'feat': feat.detach().cpu().numpy(), +# 'feat_len': feat_len.detach().cpu().numpy()} + if self.st2vec_encoder.reconstruction_cfg: + q_feat, q_feat_ids = self.st2vec_encoder.reconstructor.get_quantized_feat(feat, feat_len) + q_feat = q_feat.detach().cpu().numpy() + q_feat_ids = q_feat_ids.detach().cpu().numpy() + feat_len = feat_len.detach().cpu().numpy() + #print("=========", np.shape(q_feat_ids), np.shape(feat_len)) + for i in range(len(q_feat_ids)): + feat_str = " ".join([str(x[0]) for x in q_feat_ids[i][:feat_len[i]]]) + feat_file.write(f"{feat_str}\n") + else: + res = {'feat': feat.detach().cpu().numpy(), + 'feat_len': feat_len.detach().cpu().numpy()} + extracted_feat.append({res}) + if extracted_feat_cnt % 200 == 0 or extracted_feat_cnt == len(self._test_dl) : + feat_fp = output_dir / 'feat_{}-{}.pkl'.format(start, extracted_feat_cnt) + with feat_fp.open(mode='wb') as output_file: + print('save features to: {}'.format(feat_fp)) + pickle.dump(extracted_feat, output_file) + extracted_feat = [] # clear the list + start = extracted_feat_cnt + 1 # set the chunk start index + extracted_feat_cnt += 1 + # + print('extract feat: {}/{}'.format(extracted_feat_cnt, len(self._test_dl))) + if feat_file: + feat_file.close() + finally: + # set mode back to its original value + self.train(mode=mode) + preprocessor.featurizer.pad_to = pad_to_value + return extracted_feat diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/transducer_model.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/transducer_model.py new file mode 100644 index 0000000000000000000000000000000000000000..2413fe9d91ac2011e457ee033100096c2423bee1 --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/transducer_model.py @@ -0,0 +1,41 @@ +import torch + +from htrain.util import instantiate_from_config + + +class TransducerModel(torch.nn.Module): + + def __init__(self, encoder, decoder, joint): + super().__init__() + + self.encoder = encoder + self.decoder = decoder + self.joint = joint + assert not self.joint.fuse_loss_wer + + @classmethod + def from_cfg(cls, cfg): + encoder = instantiate_from_config(cfg.encoder) + decoder = instantiate_from_config(cfg.decoder) + joint = instantiate_from_config(cfg.joint) + return cls(encoder, decoder, joint) + + def forward(self, signal, signal_len, transcript, transcript_len): + encoded, encoded_len, extra = self.encoder(audio_signal=signal, length=signal_len) + + decoder_h, target_length = self.decoder(targets=transcript, target_length=transcript_len) + + if not self.joint.fuse_loss_wer: + joint = self.joint(encoder_outputs=encoded, decoder_outputs=decoder_h) + return joint, target_length, encoded, encoded_len + else: + # Fused joint step + loss_value, _, _, _ = self.joint( + encoder_outputs=encoded, + decoder_outputs=decoder_h, + encoder_lengths=encoded_len, + transcripts=transcript, + transcript_lengths=transcript_len, + compute_wer=False, + ) + return loss_value diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/wav2vec/__init__.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/wav2vec/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..01285dd066967eb68fe5a152228ef14c1d50c58c --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/wav2vec/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# from nemo.collections.asr.models.wav2vec.wav2vec_model import Wav2VecEncoderModel diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/wav2vec/wav2vec_config.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/wav2vec/wav2vec_config.py new file mode 100644 index 0000000000000000000000000000000000000000..d6b25c1f9d8153ad6eaeef0afa6588461b45d2b0 --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/wav2vec/wav2vec_config.py @@ -0,0 +1,201 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass, field +from enum import Enum, auto +from typing import Any, Dict, List, Optional + + +class Wav2VecActivationType(Enum): + relu = 'relu' + gelu = 'gelu' + + +class Wav2VecMaskType(Enum): + """ + Used to select configuration to compute mask lengths + static = fixed size + uniform = sample from uniform distribution [mask_other, mask_length*2] + normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element + poisson = sample from possion distribution with lambda = mask length + """ + + static = auto() + uniform = auto() + normal = auto() + poisson = auto() + + +class Wav2VecConvExtractorMode(Enum): + """ + Mode for feature extractor. default has a single group norm with d groups in the first conv block, + whereas layer_norm has layer norms in every block. + """ + + default = auto() + layer_norm = auto() + + +@dataclass +class ConvConfig: + conv_pos: int = field(default=128, metadata={'help': 'Number of filters for convolutional positional embeddings'}) + conv_pos_groups: int = field( + default=16, metadata={'help': 'Number of groups for convolutional positional embeddings'} + ) + layer_drop: float = 0.0 + + +@dataclass +class Wav2VecTransformerEncoderConfig: + encoder_layers: int = field(default=12, metadata={'help': 'Number of encoder layers in transformer model'}) + encoder_layerdrop: float = field(default=0.05, metadata={'help': 'Probability of dropping transformer layers'}) + embedding_dim: int = field(default=768, metadata={'help': 'Encoder embedding dim'}) + ffn_embedding_dim: int = field(default=3072, metadata={'help': 'Encoder embedding dim for feed forward'}) + num_attention_heads: int = field(default=8, metadata={'help': 'Number of encoder attention heads'}) + dropout: float = field(default=0.1, metadata={'help': 'Dropout probability for transformer encoder'}) + attention_dropout: float = field(default=0.1, metadata={'help': 'dropout probability for attention weights'}) + activation_dropout: float = field(default=0.0, metadata={'help': 'dropout probability after activation in FFN'}) + activation_fn: Wav2VecActivationType = field( + default=Wav2VecActivationType.gelu, metadata={'help': 'Activation for transformer'} + ) + layer_norm_first: bool = field(default=False, metadata={'help': 'Apply layer norm first within the transformer'}) + + +@dataclass +class Wav2VecTransformerConfig: + use_pytorch_transformer: bool = True + dropout: float = field(default=0.1, metadata={'help': 'Dropout probability for the transformer'}) + conv: ConvConfig = ConvConfig() + encoder: Wav2VecTransformerEncoderConfig = Wav2VecTransformerEncoderConfig() + + +@dataclass +class QuantizerConfig: + quantize_targets: bool = field(default=True, metadata={'help': 'Use quantized targets'}) + quantize_input: bool = field(default=False, metadata={'help': 'Use quantized inputs'}) + same_quantizer: bool = field(default=False, metadata={'help': 'Use the same quantizer for inputs and targets'}) + targets_bottleneck_dim: Optional[int] = None + targets_bottleneck_act_fn: Optional[str] = None + targets_bottleneck_dropout: float = 0.0 + latent_vars: int = field( + default=320, metadata={'help': 'Number of latent variables in each group of the codebook'} + ) + latent_groups: int = field(default=2, metadata={'help': 'Number of groups within the codebook'}) + latent_dim: int = field( + default=0, + metadata={ + 'help': 'If greater than 0, use dim for latent variables, else infered by final_dim / latent_groups' + }, + ) + latent_temp: tuple = field( + default=(2, 0.5, 0.999995), metadata={'help': 'Quantize temperature (start, stop, decay factor)'} + ) + # # Add by Dehua + # pre_logits_dim: int = field(default=32) + + # For finite scalar quantizer + levels: List = field( + default_factory=lambda: [8, 5, 5, 5], + metadata={'help': 'Quantization level of each dimension'}, + ) + l2_norm: bool = field(default=False, metadata={'help': 'Use L2 Norm before FSQ'}) + batch_norm: bool = field(default=False, metadata={'help': 'Use batch normalization before FSQ'}) + + +@dataclass +class ConvFeatureEncoderConfig: + extractor_mode: Wav2VecConvExtractorMode = field(default=Wav2VecConvExtractorMode.default) + conv_bias: bool = field(default=False, metadata={'help': 'Include bias in convolution feature extractor model'}) + conv_feature_layers: List = field( + default_factory=lambda: [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] + [(512, 2, 2)], + metadata={'help': 'convolutional feature extraction layers [(dim, kernel_size, stride), ...'}, + ) + + +@dataclass +class LossConfig: + prob_ppl_weight: float = field(default=0.1, metadata={'help': 'Weight applied to quantized prob perplexity loss'}) + feature_loss_weight: float = field(default=0, metadata={'help': 'Weight applied to feature L2 Norm'}) + + +@dataclass +class Wav2VecMaskingConfig: + mask_prob: float = field(default=0.65, metadata={'help': 'Probability of replacing token with mask'}) + mask_type: Wav2VecMaskType = field(default=Wav2VecMaskType.static,) + mask_emb_type: str = 'zero' + mask_other: int = field( + default=0, + metadata={'help': 'Secondary mask used for complex distributions (see help in compute_mask_indices)'}, + ) + mask_length: int = field(default=10, metadata={'help': 'Length of mask when masking time steps'}) + no_mask_overlap: bool = field(default=False, metadata={'help': 'Whether to allow masks to overlap'}) + mask_min_space: int = field( + default=1, metadata={'help': 'Minimum space beetween spans (if no overlap is enabled)'} + ) + mask_channel_prob: float = field(default=0, metadata={'help': 'Probability of replacing a feature with 0'}) + mask_channel_type: Wav2VecMaskType = field(default=Wav2VecMaskType.static,) + mask_channel_other: int = field( + default=0, + metadata={ + 'help': 'Secondary mask argument (used for more complex distributions (see help in compute_mask_indices)' + }, + ) + mask_channel_length: int = field(default=10, metadata={'help': 'Length of masks for features (channels)'}) + no_mask_channel_overlap: bool = field( + default=False, metadata={'help': 'Whether to allow channel masks to overlap'} + ) + mask_channel_min_space: int = field( + default=1, metadata={'help': 'Minimum space between spans (if no overlap is enabled)'} + ) + mask_shrink_to_batch_min: bool = True + mask_channel_shrink_to_batch_min: bool = False + + +@dataclass +class Wav2VecEncoderModelConfig: + loss: LossConfig = LossConfig() + quantizer: QuantizerConfig = QuantizerConfig() + conv_feature_encoder: ConvFeatureEncoderConfig = ConvFeatureEncoderConfig() + transformer_encoder: Wav2VecTransformerConfig = Wav2VecTransformerConfig() + masking: Wav2VecMaskingConfig = Wav2VecMaskingConfig() + + dropout_input: float = field(default=0.1, metadata={'help': 'Dropout applied to input raw features'}) + dropout_features: float = field( + default=0.1, metadata={'help': 'Dropout applied to the features generator by convolutions'} + ) + final_dim: int = field(default=0, metadata={'help': 'Project final representations and targets to this dimension'}) + n_negatives: int = field( + default=100, metadata={'help': 'Number of negatives to sample from the same audio sample'} + ) + cross_sample_negatives: int = field( + default=0, metadata={'help': 'Number of negatives to sample from any sample in the batch'} + ) + codebook_negatives: int = field(default=0, metadata={'help': 'Number of negative examples in codebook'}) + negatives_from_everywhere: bool = field( + default=False, metadata={'help': 'Sample negatives from everywhere, not just masked states'} + ) + logit_temp: float = field(default=0.1, metadata={'help': 'Temperature to divide logits by'}) + target_glu: bool = field(default=False, metadata={'help': 'Adds project and applies GLU to targets'}) + feature_grad_mult: float = field(default=0.1, metadata={'help': 'Multiply extracted feature gradients'}) + + train_ds: Optional[Dict[Any, Any]] = None + validation_ds: Optional[Dict[Any, Any]] = None + test_ds: Optional[Dict[Any, Any]] = None + expected_gpu_num: int = 1 + optim: Optional[Dict[Any, Any]] = None diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/wav2vec/wav2vec_model.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/wav2vec/wav2vec_model.py new file mode 100644 index 0000000000000000000000000000000000000000..23734a78ada6532ee0f0abcef0efa9c569cc3ceb --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/models/wav2vec/wav2vec_model.py @@ -0,0 +1,578 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# Copyright (c) Facebook, Inc. and its affiliates. +# +import logging +from math import ceil +from typing import Dict, Optional, Union + +import torch +from omegaconf import DictConfig, OmegaConf +from pytorch_lightning import Trainer +from torch import nn + +from nemo.collections.asr.data import audio_to_text_dataset +from nemo.collections.asr.losses.wav2vecloss import Wav2VecLoss +from nemo.collections.asr.models.wav2vec.wav2vec_config import Wav2VecEncoderModelConfig +from nemo.collections.asr.modules.wav2vec_modules import GumbelVectorQuantizer, compute_mask_indices +from nemo.collections.asr.parts.perturb import process_augmentations +from nemo.collections.asr.parts.wav2vec import ConvFeatureEncoder, GradMultiply, Wav2VecTransformerEncoder, \ + TransformerEncoder +from nemo.core import ModelPT +from nemo.core.classes.common import PretrainedModelInfo, typecheck +from nemo.core.neural_types import AudioSignal, EncodedRepresentation, LossType, MaskType, NeuralType +from nemo.core.neural_types.elements import BoolType, FloatType + + +def buffered_arange(max): + if not hasattr(buffered_arange, "buf"): + buffered_arange.buf = torch.LongTensor() + if max > buffered_arange.buf.numel(): + buffered_arange.buf.resize_(max) + torch.arange(max, out=buffered_arange.buf) + return buffered_arange.buf[:max] + + +class Wav2VecEncoderModel(ModelPT): + def __init__(self, cfg: DictConfig, trainer: Trainer = None): + # Get global rank and total number of GPU workers for IterableDataset partitioning, if applicable + self.global_rank = 0 + self.world_size = 1 + self.local_rank = 0 + if trainer is not None: + self.global_rank = (trainer.node_rank * trainer.num_gpus) + trainer.local_rank + self.world_size = trainer.num_nodes * trainer.num_gpus + self.local_rank = trainer.local_rank + + super().__init__(cfg=cfg, trainer=trainer) + + schema = OmegaConf.structured(Wav2VecEncoderModelConfig) + if isinstance(cfg, dict): + cfg = OmegaConf.create(cfg) + elif not isinstance(cfg, DictConfig): + raise ValueError(f"cfg was type: {type(cfg)}. Expected either a dict or a DictConfig") + + cfg = OmegaConf.create(OmegaConf.to_container(cfg, resolve=True)) + cfg = OmegaConf.merge(schema, cfg) + + feature_enc_layers = cfg.conv_feature_encoder.conv_feature_layers + self.embed = feature_enc_layers[-1][0] # Select last conv output layer dimension + + self.feature_extractor = ConvFeatureEncoder( + conv_layers=feature_enc_layers, + mode=cfg.conv_feature_encoder.extractor_mode, + conv_bias=cfg.conv_feature_encoder.conv_bias, + ) + + encoder_embed_dim = cfg.transformer_encoder.encoder.embedding_dim + self.post_extract_proj = ( + nn.Linear(self.embed, encoder_embed_dim) + if self.embed != encoder_embed_dim and not cfg.quantizer.quantize_input + else None + ) + assert not cfg.quantizer.quantize_input # finetune expect this + + self.mask_cfg = cfg.masking + + self.dropout_input = nn.Dropout(cfg.dropout_input) + self.dropout_features = nn.Dropout(cfg.dropout_features) + + self.feature_grad_mult = cfg.feature_grad_mult + + self.quantizer = None + self.input_quantizer = None + + self.n_negatives = cfg.n_negatives + self.cross_sample_negatives = cfg.cross_sample_negatives + self.codebook_negatives = cfg.codebook_negatives + self.negatives_from_everywhere = cfg.negatives_from_everywhere + + final_dim = cfg.final_dim if cfg.final_dim > 0 else encoder_embed_dim + self.final_dim = final_dim + self.quantize_targets = cfg.quantizer.quantize_targets + if self.quantize_targets: + assert cfg.quantizer.targets_bottleneck_dim is None + vq_dim = cfg.quantizer.latent_dim if cfg.quantizer.latent_dim > 0 else final_dim + self.quantizer = GumbelVectorQuantizer( + dim=self.embed, + num_vars=cfg.quantizer.latent_vars, + temp=cfg.quantizer.latent_temp, + groups=cfg.quantizer.latent_groups, + combine_groups=False, + vq_dim=vq_dim, + time_first=True, + ) + self.project_q = nn.Linear(vq_dim, final_dim) + else: + assert cfg.loss.prob_ppl_weight == 0 + targets_bottleneck_dim = cfg.quantizer.targets_bottleneck_dim + if targets_bottleneck_dim is None: + self.project_q = nn.Linear(self.embed, final_dim) + else: + act_fn_dic = {'relu': nn.ReLU, 'gelu': nn.GELU} + targets_proj_act_fn = cfg.quantizer.targets_bottleneck_act_fn + targets_proj_layers = ( + [nn.Linear(self.embed, targets_bottleneck_dim)] + + ([] if targets_proj_act_fn is None else [act_fn_dic[targets_proj_act_fn]]) + + [nn.Linear(targets_bottleneck_dim, final_dim)] + + ) + self.project_q = torch.nn.Sequential(*targets_proj_layers) + + if cfg.quantizer.quantize_input: + if cfg.quantizer.same_quantizer and self.quantizer is not None: + vq_dim = final_dim + self.input_quantizer = self.quantizer + else: + vq_dim = cfg.quantizer.latent_dim if cfg.quantizer.latent_dim > 0 else encoder_embed_dim + self.input_quantizer = GumbelVectorQuantizer( + dim=self.embed, + num_vars=cfg.quantizer.latent_vars, + temp=cfg.quantizer.latent_temp, + groups=cfg.quantizer.latent_groups, + combine_groups=False, + vq_dim=vq_dim, + time_first=True, + ) + self.project_inp = nn.Linear(vq_dim, encoder_embed_dim) + + self.mask_emb = nn.Parameter(torch.FloatTensor(encoder_embed_dim).uniform_()) + + if cfg.transformer_encoder.use_pytorch_transformer: + self.encoder = Wav2VecTransformerEncoder(cfg.transformer_encoder) + else: + self.encoder = TransformerEncoder(cfg.transformer_encoder) + self.layer_norm = nn.LayerNorm(self.embed) + + self.target_glu = None + if cfg.target_glu: + self.target_glu = nn.Sequential(nn.Linear(final_dim, final_dim * 2), nn.GLU()) + + self.final_proj = nn.Linear(encoder_embed_dim, final_dim) + self.loss = Wav2VecLoss( + feature_loss_weight=cfg.loss.feature_loss_weight, + prob_ppl_weight=cfg.loss.prob_ppl_weight, + logit_temp=cfg.logit_temp, + ) + self._prev_log_step = -1 + + def training_step(self, batch, batch_idx): + loss, contrastive_loss, feature_loss, prob_ppl_loss, cur_temp, prob_ppl, _ = self._step(batch) + + if self.global_step > self._prev_log_step: + self._prev_log_step = self.global_step + tensorboard = self.logger.experiment + tensorboard.add_scalar('loss', loss, self.global_step) + tensorboard.add_scalar('contrastive_loss', contrastive_loss, self.global_step) + tensorboard.add_scalar('feature_loss', feature_loss, self.global_step) + if self.quantize_targets: + tensorboard.add_scalar('prob_ppl_loss', prob_ppl_loss, self.global_step) + tensorboard.add_scalar('temp', cur_temp, self.global_step) + tensorboard.add_scalar('prob_ppl', prob_ppl, self.global_step) + tensorboard.add_scalar('learning_rate', self._optimizer.param_groups[0]['lr'], self.global_step) + return {'loss': loss} + + def validation_step(self, batch, batch_idx, dataloader_idx=0): + loss, contrastive_loss, feature_loss, prob_ppl_loss, _, _, accuracy = self._step(batch) + self.log('val_loss', loss, prog_bar=True, on_epoch=True, sync_dist=True) + self.log('val_accuracy', accuracy, prog_bar=True, on_step=False, on_epoch=True, sync_dist=False) + + def test_step(self, batch, batch_idx, dataloader_idx=0): + loss, contrastive_loss, feature_loss, prob_ppl_loss, _, _, accuracy = self._step(batch) + self.log('test_loss', loss, prog_bar=True, on_epoch=True, sync_dist=True) + self.log('test_accuracy', accuracy, prog_bar=True, on_step=False, on_epoch=True, sync_dist=False) + + def _step(self, batch): + audio_signal, audio_lengths = batch + + self._update_quantizer_temp() + logits, targets, sampled_negatives, _, features_penalty, prob_ppl_loss, cur_temp, prob_ppl = self( + source=audio_signal, source_len=audio_lengths + ) + loss, contrastive_loss, feature_loss, prob_ppl_loss, accuracy = self.loss( + logits=logits, + targets=targets, + negatives=sampled_negatives, + prob_ppl_loss=prob_ppl_loss, + feature_loss=features_penalty, + compute_accuracy=not self.training + ) + return loss, contrastive_loss, feature_loss, prob_ppl_loss, cur_temp, prob_ppl, accuracy + + @classmethod + def list_available_models(cls) -> Optional[PretrainedModelInfo]: + return None + + @property + def input_types(self) -> Optional[Dict[str, NeuralType]]: + return { + "source": NeuralType(('B', 'T'), AudioSignal()), + "padding_mask": NeuralType(('B', 'T'), MaskType(), optional=True), + "mask": NeuralType(elements_type=BoolType(), optional=True), + "features_only": NeuralType(elements_type=BoolType(), optional=True), + } + + @property + def output_types(self) -> Optional[Dict[str, NeuralType]]: + return { + "logits": NeuralType(('B', 'T', 'D'), EncodedRepresentation()), + "targets": NeuralType(('B', 'T', 'D'), EncodedRepresentation(), optional=True), + "sampled_negatives": NeuralType(('N', 'B', 'T', 'D'), EncodedRepresentation(), optional=True), + "padding_mask": NeuralType(('B', 'T'), MaskType(), optional=True), + "features_penalty": NeuralType(elements_type=LossType(), optional=True), + "prob_ppl_loss": NeuralType(elements_type=LossType(), optional=True), + "cur_codebook_temp": NeuralType(elements_type=FloatType(), optional=True), + } + + def forward(self, source, source_len, *, mask=True, features_only=False) -> tuple: + prob_ppl_loss, cur_temp = None, None + + if self.feature_grad_mult > 0: + features = self.feature_extractor(source) + if self.feature_grad_mult != 1.0: + features = GradMultiply.apply(features, self.feature_grad_mult) + else: + with torch.no_grad(): + features = self.feature_extractor(source) + feature_lens = self.feature_extractor.get_subsampled_lens(source_len) + padding_mask = self._create_padding_mask(feature_lens) + assert feature_lens.max() == features.shape[2] == padding_mask.shape[1] + + features = features.transpose(1, 2) + + features_penalty = features[~padding_mask].float().pow(2).mean() # L2 Norm on features + + features = self.layer_norm(features) + unmasked_features = features.clone() + + if self.post_extract_proj is not None: + features = self.post_extract_proj(features) + + features = self.dropout_input(features) + unmasked_features = self.dropout_features(unmasked_features) + + assert self.input_quantizer is None + # if self.input_quantizer: + # features, prob_ppl_loss, cur_codebook_temp = self.input_quantizer(features) + # features = self.project_inp(features) + if mask: + logits, mask_indices, mask_num = self.apply_mask(features, padding_mask) + if mask_indices is not None: + targets = unmasked_features[mask_indices] + if self.mask_cfg.mask_shrink_to_batch_min: + targets = targets.view( + unmasked_features.size(0), -1, unmasked_features.size(-1) + ) + else: + # fake batch dim 1 + targets = targets.view( + 1, -1, unmasked_features.size(-1) + ) + assert targets.shape[1] == sum(mask_num) + else: + targets = unmasked_features + else: + logits = features + targets = unmasked_features + mask_indices = None + mask_num = None + + logits = self.encoder(logits, padding_mask=padding_mask) + + if features_only: + return logits, padding_mask + + if self.quantize_targets: + targets, prob_ppl_loss, cur_temp, prob_ppl = self.quantizer(targets) + targets = self.project_q(targets) + + if self.negatives_from_everywhere: + assert self.mask_cfg.mask_shrink_to_batch_min + neg_cands, *_ = self.quantizer(unmasked_features) + sampled_negatives, _ = self.sample_negatives(neg_cands, targets.size(1)) + sampled_negatives = self.project_q(sampled_negatives) + else: + if self.mask_cfg.mask_shrink_to_batch_min: + sampled_negatives, _ = self.sample_negatives(targets, targets.size(1)) + else: + sampled_negatives, _ = self.sample_negatives_flat(targets, mask_num) + + if self.codebook_negatives > 0: + assert self.mask_cfg.mask_shrink_to_batch_min + cb_negs = self.quantizer.sample_from_codebook( + targets.size(0) * targets.size(1), self.codebook_negatives + ) + cb_negs = cb_negs.view( + self.codebook_negatives, targets.size(0), targets.size(1), -1 + ) # order doesnt matter + cb_negs = self.project_q(cb_negs) + sampled_negatives = torch.cat([sampled_negatives, cb_negs], dim=0) + else: + targets = self.project_q(targets) + prob_ppl = None + + if self.negatives_from_everywhere: + assert self.mask_cfg.mask_shrink_to_batch_min + sampled_negatives, _ = self.sample_negatives(unmasked_features, targets.size(1)) + sampled_negatives = self.project_q(sampled_negatives) + else: + if self.mask_cfg.mask_shrink_to_batch_min: + sampled_negatives, _ = self.sample_negatives(targets, targets.size(1)) + else: + sampled_negatives, _ = self.sample_negatives_flat(targets, mask_num) + + mask_logits = logits[mask_indices] + if self.mask_cfg.mask_shrink_to_batch_min: + mask_logits = mask_logits.view(logits.size(0), -1, logits.size(-1)) + else: + # fake batch dim to 1 + mask_logits = mask_logits.view(1, -1, logits.size(-1)) + + if self.target_glu: + targets = self.target_glu(targets) + sampled_negatives = self.target_glu(sampled_negatives) + + mask_logits = self.final_proj(mask_logits) + + return mask_logits, targets, sampled_negatives, padding_mask, features_penalty, prob_ppl_loss, cur_temp, prob_ppl + + def extract_features(self, source, audio_lengths, mask=False): + padding_mask = self._create_padding_mask(audio_lengths) + return self(source=source, padding_mask=padding_mask, mask=mask, features_only=True) + + def remove_pretraining_modules(self): + self.quantizer = None + self.project_q = None + self.target_glu = None + self.final_proj = None + + def _update_quantizer_temp(self): + if self.quantizer: + self.quantizer.set_num_updates(self.trainer.global_step) + if self.input_quantizer: + self.input_quantizer.set_num_updates(self.trainer.global_step) + + def apply_mask(self, x, padding_mask): + B, T, C = x.shape + if self.mask_cfg.mask_prob > 0: + mask_indices, mask_num = compute_mask_indices( + (B, T), + padding_mask, + self.mask_cfg.mask_prob, + self.mask_cfg.mask_length, + self.mask_cfg.mask_type, + self.mask_cfg.mask_other, + min_masks=2, + no_overlap=self.mask_cfg.no_mask_overlap, + min_space=self.mask_cfg.mask_min_space, + shrink_to_batch_min=self.mask_cfg.mask_shrink_to_batch_min, + ) + mask_indices = torch.from_numpy(mask_indices).to(x.device) + mask_emb = self.mask_emb.type_as(x) + x[mask_indices] = mask_emb + else: + mask_indices = None + + if self.mask_cfg.mask_channel_prob > 0: + # assert self.mask_cfg.mask_shrink_to_batch_min + mask_channel_indices, _ = compute_mask_indices( + (B, C), + None, + self.mask_cfg.mask_channel_prob, + self.mask_cfg.mask_channel_length, + self.mask_cfg.mask_channel_type, + self.mask_cfg.mask_channel_other, + no_overlap=self.mask_cfg.no_mask_channel_overlap, + min_space=self.mask_cfg.mask_channel_min_space, + shrink_to_batch_min=self.mask_cfg.mask_channel_shrink_to_batch_min, + ) + mask_channel_indices = torch.from_numpy(mask_channel_indices).to(x.device).unsqueeze(1).expand(-1, T, -1) + x[mask_channel_indices] = 0 + + assert len(mask_num) == B + return x, mask_indices, mask_num + + def sample_negatives(self, y, num): + + if self.n_negatives == 0 and self.cross_sample_negatives == 0: + return y.new(0) + + bsz, tsz, fsz = y.shape + y = y.view(-1, fsz) # BTC => (BxT)C + + cross_high = tsz * bsz + high = tsz + with torch.no_grad(): + assert high > 1, f"{bsz, tsz, fsz}" + + if self.n_negatives > 0: + tszs = buffered_arange(num).unsqueeze(-1).expand(-1, self.n_negatives).flatten() + + neg_idxs = torch.randint(low=0, high=high - 1, size=(bsz, self.n_negatives * num)) + neg_idxs[neg_idxs >= tszs] += 1 + + if self.cross_sample_negatives > 0: + tszs = buffered_arange(num).unsqueeze(-1).expand(-1, self.cross_sample_negatives).flatten() + + cross_neg_idxs = torch.randint( + low=0, high=cross_high - 1, size=(bsz, self.cross_sample_negatives * num), + ) + cross_neg_idxs[cross_neg_idxs >= tszs] += 1 + + if self.n_negatives > 0: + for i in range(1, bsz): + neg_idxs[i] += i * high + else: + neg_idxs = cross_neg_idxs + + if self.cross_sample_negatives > 0 and self.n_negatives > 0: + neg_idxs = torch.cat([neg_idxs, cross_neg_idxs], dim=1) + + negs = y[neg_idxs.view(-1)] + negs = negs.view(bsz, num, self.n_negatives + self.cross_sample_negatives, fsz).permute( + 2, 0, 1, 3 + ) # to NxBxTxC + return negs, neg_idxs + + def sample_negatives_flat(self, y, nums): + + if self.n_negatives == 0 and self.cross_sample_negatives == 0: + return y.new(0) + + bsz, tsz, fsz = y.shape + assert bsz == 1 and tsz == sum(nums) # fake batch dim + y = y.view(-1, fsz) # BTC => (BxT)C + + # cross_high = tsz * bsz + + neg_idxs_l = [] + idx_start = 0 + with torch.no_grad(): + for i, num_i in enumerate(nums): + assert num_i > 1, f"{bsz, tsz, fsz}" + + assert self.n_negatives > 0 + tszs_i = buffered_arange(num_i).unsqueeze(-1).expand(-1, self.n_negatives).flatten() + + high_i = num_i + neg_idxs_i = torch.randint(low=0, high=high_i - 1, size=(self.n_negatives * num_i,)) + neg_idxs_i[neg_idxs_i >= tszs_i] += 1 + + neg_idxs_i += idx_start + idx_start += num_i + + neg_idxs_l.append(neg_idxs_i) + + assert self.cross_sample_negatives == 0 + # if self.cross_sample_negatives > 0: + # tszs = buffered_arange(num_i).unsqueeze(-1).expand(-1, self.cross_sample_negatives).flatten() + # + # cross_neg_idxs = torch.randint( + # low=0, high=cross_high - 1, size=(self.cross_sample_negatives * num_i), + # ) + # cross_neg_idxs[cross_neg_idxs >= tszs] += 1 + + # if self.n_negatives <= 0: + # neg_idxs = cross_neg_idxs + + # if self.cross_sample_negatives > 0 and self.n_negatives > 0: + # neg_idxs = torch.cat([neg_idxs, cross_neg_idxs], dim=1) + + neg_idxs = torch.cat(neg_idxs_l) + assert neg_idxs.ndim == 1 + + negs = y[neg_idxs] + negs = negs.view(bsz, sum(nums), self.n_negatives + self.cross_sample_negatives, fsz).permute( + 2, 0, 1, 3 + ) # to NxBxTxC + return negs, neg_idxs + + def _create_padding_mask(self, audio_lengths): + # Broadcast to vectorize creating the padding mask + max_len = max(audio_lengths) + padding_mask = torch.arange(max_len, device=audio_lengths.device) + padding_mask = padding_mask.expand(len(audio_lengths), max_len) < audio_lengths.unsqueeze(1) + # Negate to false where no padding + padding_mask = ~padding_mask + return padding_mask + + def setup_training_data(self, train_data_config: Optional[Union[DictConfig, Dict]]): + if 'shuffle' not in train_data_config: + train_data_config['shuffle'] = True + + # preserve config + self._update_dataset_config(dataset_name='train', config=train_data_config) + + self._train_dl = self._setup_dataloader_from_config(config=train_data_config) + + # Need to set this because if using an IterableDataset, the length of the dataloader is the total number + # of samples rather than the number of batches, and this messes up the tqdm progress bar. + # So we set the number of steps manually (to the correct number) to fix this. + if 'is_tarred' in train_data_config and train_data_config['is_tarred']: + # We also need to check if limit_train_batches is already set. + # If it's an int, we assume that the user has set it to something sane, i.e. <= # training batches, + # and don't change it. Otherwise, adjust batches accordingly if it's a float (including 1.0). + if isinstance(self._trainer.limit_train_batches, float): + self._trainer.limit_train_batches = int( + self._trainer.limit_train_batches + * ceil((len(self._train_dl.dataset) / self.world_size) / train_data_config['batch_size']) + ) + + def setup_validation_data(self, val_data_config: Optional[Union[DictConfig, Dict]]): + if 'shuffle' not in val_data_config: + val_data_config['shuffle'] = False + + # preserve config + self._update_dataset_config(dataset_name='validation', config=val_data_config) + + self._validation_dl = self._setup_dataloader_from_config(config=val_data_config) + + def setup_test_data(self, test_data_config: Optional[Union[DictConfig, Dict]]): + if 'shuffle' not in test_data_config: + test_data_config['shuffle'] = False + + # preserve config + self._update_dataset_config(dataset_name='test', config=test_data_config) + + self._test_dl = self._setup_dataloader_from_config(config=test_data_config) + + def _setup_dataloader_from_config(self, config: Optional[Dict]): + + if 'augmentor' in config: + augmentor = process_augmentations(config['augmentor']) + else: + augmentor = None + + shuffle = config['shuffle'] + + if 'manifest_filepath' in config and config['manifest_filepath'] is None: + logging.warning(f"Could not load dataset as `manifest_filepath` was None. Provided config : {config}") + return None + + dataset = audio_to_text_dataset.get_audio_dataset(config=config, augmentor=augmentor) + + return torch.utils.data.DataLoader( + dataset=dataset, + batch_size=config['batch_size'], + collate_fn=dataset.collate_fn, + drop_last=config.get('drop_last', False), + shuffle=shuffle, + num_workers=config.get('num_workers', 0), + pin_memory=config.get('pin_memory', False), + ) diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/modules/__init__.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..33393da83c446bbfd6da76f08a3f07b4cc826233 --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/modules/__init__.py @@ -0,0 +1,33 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nemo.collections.asr.modules.audio_preprocessing import ( + AudioToMelSpectrogramPreprocessor, + AudioToMFCCPreprocessor, + CropOrPadSpectrogramAugmentation, + SpectrogramAugmentation, +) +from nemo.collections.asr.modules.beam_search_decoder import BeamSearchDecoderWithLM +from nemo.collections.asr.modules.conformer_encoder import ConformerEncoder +from nemo.collections.asr.modules.conv_asr import ( + ConvASRDecoder, + ConvASRDecoderClassification, + ConvASREncoder, + SpeakerDecoder, +) +from nemo.collections.asr.modules.conv_transformer_encoder import ConvTransformerEncoder +from nemo.collections.asr.modules.text_conv_transformer_encoder import TextConvTransformerEncoder +from nemo.collections.asr.modules.lstm_decoder import LSTMDecoder +from nemo.collections.asr.modules.rnnt import RNNTDecoder, RNNTJoint +from nemo.collections.asr.modules.transformer_t_decoder import TransformerTDecoder diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/modules/audio_preprocessing.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/modules/audio_preprocessing.py new file mode 100644 index 0000000000000000000000000000000000000000..e6784b727f39f6c00c86accc3dbe9fbc90d27746 --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/modules/audio_preprocessing.py @@ -0,0 +1,597 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Any, Optional + +import torch +from packaging import version + +from nemo.collections.asr.parts.features import FilterbankFeatures +from nemo.collections.asr.parts.spectr_augment import SpecAugment, SpecCutout +from nemo.core.classes import NeuralModule, typecheck +from nemo.core.neural_types import ( + AudioSignal, + LengthsType, + MelSpectrogramType, + MFCCSpectrogramType, + NeuralType, + SpectrogramType, +) +from nemo.utils import logging + +try: + import torchaudio + import torchaudio.functional + import torchaudio.transforms + + TORCHAUDIO_VERSION = version.parse(torchaudio.__version__) + TORCHAUDIO_VERSION_MIN = version.parse('0.5') + + HAVE_TORCHAUDIO = True +except ModuleNotFoundError: + HAVE_TORCHAUDIO = False + +__all__ = [ + 'AudioToMelSpectrogramPreprocessor', + 'AudioToMFCCPreprocessor', + 'SpectrogramAugmentation', + 'CropOrPadSpectrogramAugmentation', +] + + +class AudioPreprocessor(NeuralModule, ABC): + """ + An interface for Neural Modules that performs audio pre-processing, + transforming the wav files to features. + """ + + def __init__(self, win_length, hop_length): + super().__init__() + + self.win_length = win_length + self.hop_length = hop_length + + self.torch_windows = { + 'hann': torch.hann_window, + 'hamming': torch.hamming_window, + 'blackman': torch.blackman_window, + 'bartlett': torch.bartlett_window, + 'ones': torch.ones, + None: torch.ones, + } + + @typecheck() + @torch.no_grad() + def forward(self, input_signal, length): + processed_signal, processed_length = self.get_features(input_signal, length) + + return processed_signal, processed_length + + @abstractmethod + def get_features(self, input_signal, length): + # Called by forward(). Subclasses should implement this. + pass + + +class AudioToMelSpectrogramPreprocessor(AudioPreprocessor): + """Featurizer module that converts wavs to mel spectrograms. + We don't use torchaudio's implementation here because the original + implementation is not the same, so for the sake of backwards-compatibility + this will use the old FilterbankFeatures for now. + Args: + sample_rate (int): Sample rate of the input audio data. + Defaults to 16000 + window_size (float): Size of window for fft in seconds + Defaults to 0.02 + window_stride (float): Stride of window for fft in seconds + Defaults to 0.01 + n_window_size (int): Size of window for fft in samples + Defaults to None. Use one of window_size or n_window_size. + n_window_stride (int): Stride of window for fft in samples + Defaults to None. Use one of window_stride or n_window_stride. + window (str): Windowing function for fft. can be one of ['hann', + 'hamming', 'blackman', 'bartlett'] + Defaults to "hann" + normalize (str): Can be one of ['per_feature', 'all_features']; all + other options disable feature normalization. 'all_features' + normalizes the entire spectrogram to be mean 0 with std 1. + 'pre_features' normalizes per channel / freq instead. + Defaults to "per_feature" + n_fft (int): Length of FT window. If None, it uses the smallest power + of 2 that is larger than n_window_size. + Defaults to None + preemph (float): Amount of pre emphasis to add to audio. Can be + disabled by passing None. + Defaults to 0.97 + features (int): Number of mel spectrogram freq bins to output. + Defaults to 64 + lowfreq (int): Lower bound on mel basis in Hz. + Defaults to 0 + highfreq (int): Lower bound on mel basis in Hz. + Defaults to None + log (bool): Log features. + Defaults to True + log_zero_guard_type(str): Need to avoid taking the log of zero. There + are two options: "add" or "clamp". + Defaults to "add". + log_zero_guard_value(float, or str): Add or clamp requires the number + to add with or clamp to. log_zero_guard_value can either be a float + or "tiny" or "eps". torch.finfo is used if "tiny" or "eps" is + passed. + Defaults to 2**-24. + dither (float): Amount of white-noise dithering. + Defaults to 1e-5 + pad_to (int): Ensures that the output size of the time dimension is + a multiple of pad_to. + Defaults to 16 + frame_splicing (int): Defaults to 1 + stft_exact_pad (bool): If True, uses pytorch_stft and convolutions with + padding such that num_frames = num_samples / hop_length. If False, + stft_conv will be used to determine how stft will be performed. + Defaults to False + stft_conv (bool): If True, uses pytorch_stft and convolutions. If + False, uses torch.stft. + Defaults to False + pad_value (float): The value that shorter mels are padded with. + Defaults to 0 + mag_power (float): The power that the linear spectrogram is raised to + prior to multiplication with mel basis. + Defaults to 2 for a power spec + """ + + def save_to(self, save_path: str): + pass + + @classmethod + def restore_from(cls, restore_path: str): + pass + + @property + def input_types(self): + """Returns definitions of module input ports. + """ + return { + "input_signal": NeuralType(('B', 'T'), AudioSignal(freq=self._sample_rate)), + "length": NeuralType(tuple('B'), LengthsType()), + } + + @property + def output_types(self): + """Returns definitions of module output ports. + processed_signal: + 0: AxisType(BatchTag) + 1: AxisType(MelSpectrogramSignalTag) + 2: AxisType(ProcessedTimeTag) + processed_length: + 0: AxisType(BatchTag) + """ + return { + "processed_signal": NeuralType(('B', 'D', 'T'), MelSpectrogramType()), + "processed_length": NeuralType(tuple('B'), LengthsType()), + } + + def __init__( + self, + sample_rate=16000, + window_size=0.02, + window_stride=0.01, + n_window_size=None, + n_window_stride=None, + window="hann", + normalize="per_feature", + n_fft=None, + preemph=0.97, + features=64, + lowfreq=0, + highfreq=None, + log=True, + log_zero_guard_type="add", + log_zero_guard_value=2 ** -24, + dither=1e-5, + dither_train_only=False, + pad_to=16, + frame_splicing=1, + stft_exact_pad=False, + stft_conv=False, + pad_value=0, + mag_power=2.0, + normalize_time_domain=False, + ): + super().__init__(n_window_size, n_window_stride) + + self._sample_rate = sample_rate + if window_size and n_window_size: + raise ValueError(f"{self} received both window_size and " f"n_window_size. Only one should be specified.") + if window_stride and n_window_stride: + raise ValueError( + f"{self} received both window_stride and " f"n_window_stride. Only one should be specified." + ) + if window_size: + n_window_size = int(window_size * self._sample_rate) + if window_stride: + n_window_stride = int(window_stride * self._sample_rate) + + self.featurizer = FilterbankFeatures( + sample_rate=self._sample_rate, + n_window_size=n_window_size, + n_window_stride=n_window_stride, + window=window, + normalize=normalize, + n_fft=n_fft, + preemph=preemph, + nfilt=features, + lowfreq=lowfreq, + highfreq=highfreq, + log=log, + log_zero_guard_type=log_zero_guard_type, + log_zero_guard_value=log_zero_guard_value, + dither=dither, + dither_train_only=dither_train_only, + pad_to=pad_to, + frame_splicing=frame_splicing, + stft_exact_pad=stft_exact_pad, + stft_conv=stft_conv, + pad_value=pad_value, + mag_power=mag_power, + normalize_time_domain=normalize_time_domain, + ) + + def get_features(self, input_signal, length): + return self.featurizer(input_signal, length) + + @property + def filter_banks(self): + return self.featurizer.filter_banks + + +class AudioToMFCCPreprocessor(AudioPreprocessor): + """Preprocessor that converts wavs to MFCCs. + Uses torchaudio.transforms.MFCC. + Args: + sample_rate: The sample rate of the audio. + Defaults to 16000. + window_size: Size of window for fft in seconds. Used to calculate the + win_length arg for mel spectrogram. + Defaults to 0.02 + window_stride: Stride of window for fft in seconds. Used to caculate + the hop_length arg for mel spect. + Defaults to 0.01 + n_window_size: Size of window for fft in samples + Defaults to None. Use one of window_size or n_window_size. + n_window_stride: Stride of window for fft in samples + Defaults to None. Use one of window_stride or n_window_stride. + window: Windowing function for fft. can be one of ['hann', + 'hamming', 'blackman', 'bartlett', 'none', 'null']. + Defaults to 'hann' + n_fft: Length of FT window. If None, it uses the smallest power of 2 + that is larger than n_window_size. + Defaults to None + lowfreq (int): Lower bound on mel basis in Hz. + Defaults to 0 + highfreq (int): Lower bound on mel basis in Hz. + Defaults to None + n_mels: Number of mel filterbanks. + Defaults to 64 + n_mfcc: Number of coefficients to retain + Defaults to 64 + dct_type: Type of discrete cosine transform to use + norm: Type of norm to use + log: Whether to use log-mel spectrograms instead of db-scaled. + Defaults to True. + """ + + @property + def input_types(self): + """Returns definitions of module input ports. + """ + return { + "input_signal": NeuralType(('B', 'T'), AudioSignal(freq=self._sample_rate)), + "length": NeuralType(tuple('B'), LengthsType()), + } + + @property + def output_types(self): + """Returns definitions of module output ports. + """ + return { + "processed_signal": NeuralType(('B', 'D', 'T'), MFCCSpectrogramType()), + "processed_length": NeuralType(tuple('B'), LengthsType()), + } + + def save_to(self, save_path: str): + pass + + @classmethod + def restore_from(cls, restore_path: str): + pass + + def __init__( + self, + sample_rate=16000, + window_size=0.02, + window_stride=0.01, + n_window_size=None, + n_window_stride=None, + window='hann', + n_fft=None, + lowfreq=0.0, + highfreq=None, + n_mels=64, + n_mfcc=64, + dct_type=2, + norm='ortho', + log=True, + ): + self._sample_rate = sample_rate + if not HAVE_TORCHAUDIO: + logging.error('Could not import torchaudio. Some features might not work.') + + raise ModuleNotFoundError( + "torchaudio is not installed but is necessary for " + "AudioToMFCCPreprocessor. We recommend you try " + "building it from source for the PyTorch version you have." + ) + if window_size and n_window_size: + raise ValueError(f"{self} received both window_size and " f"n_window_size. Only one should be specified.") + if window_stride and n_window_stride: + raise ValueError( + f"{self} received both window_stride and " f"n_window_stride. Only one should be specified." + ) + # Get win_length (n_window_size) and hop_length (n_window_stride) + if window_size: + n_window_size = int(window_size * self._sample_rate) + if window_stride: + n_window_stride = int(window_stride * self._sample_rate) + + super().__init__(n_window_size, n_window_stride) + + mel_kwargs = {} + + mel_kwargs['f_min'] = lowfreq + mel_kwargs['f_max'] = highfreq + mel_kwargs['n_mels'] = n_mels + + mel_kwargs['n_fft'] = n_fft or 2 ** math.ceil(math.log2(n_window_size)) + + mel_kwargs['win_length'] = n_window_size + mel_kwargs['hop_length'] = n_window_stride + + # Set window_fn. None defaults to torch.ones. + window_fn = self.torch_windows.get(window, None) + if window_fn is None: + raise ValueError( + f"Window argument for AudioProcessor is invalid: {window}." + f"For no window function, use 'ones' or None." + ) + mel_kwargs['window_fn'] = window_fn + + # Use torchaudio's implementation of MFCCs as featurizer + self.featurizer = torchaudio.transforms.MFCC( + sample_rate=self._sample_rate, + n_mfcc=n_mfcc, + dct_type=dct_type, + norm=norm, + log_mels=log, + melkwargs=mel_kwargs, + ) + + def get_features(self, input_signal, length): + features = self.featurizer(input_signal) + seq_len = torch.ceil(length.to(torch.float32) / self.hop_length).to(dtype=torch.long) + return features, seq_len + + +class SpectrogramAugmentation(NeuralModule): + """ + Performs time and freq cuts in one of two ways. + SpecAugment zeroes out vertical and horizontal sections as described in + SpecAugment (https://arxiv.org/abs/1904.08779). Arguments for use with + SpecAugment are `freq_masks`, `time_masks`, `freq_width`, and `time_width`. + SpecCutout zeroes out rectangulars as described in Cutout + (https://arxiv.org/abs/1708.04552). Arguments for use with Cutout are + `rect_masks`, `rect_freq`, and `rect_time`. + Args: + freq_masks (int): how many frequency segments should be cut. + Defaults to 0. + time_masks (int): how many time segments should be cut + Defaults to 0. + freq_width (int): maximum number of frequencies to be cut in one + segment. + Defaults to 10. + time_width (int): maximum number of time steps to be cut in one + segment + Defaults to 10. + rect_masks (int): how many rectangular masks should be cut + Defaults to 0. + rect_freq (int): maximum size of cut rectangles along the frequency + dimension + Defaults to 5. + rect_time (int): maximum size of cut rectangles along the time + dimension + Defaults to 25. + """ + + def save_to(self, save_path: str): + pass + + @classmethod + def restore_from(cls, restore_path: str): + pass + + @property + def input_types(self): + """Returns definitions of module input types + """ + return {"input_spec": NeuralType(('B', 'D', 'T'), SpectrogramType()), + "length": NeuralType(tuple('B'), LengthsType())} + + @property + def output_types(self): + """Returns definitions of module output types + """ + return {"augmented_spec": NeuralType(('B', 'D', 'T'), SpectrogramType())} + + def __init__( + self, + freq_masks=0, + time_masks=0, + max_time_masks=20, + freq_width=10, + time_width=10, + rect_masks=0, + rect_time=5, + rect_freq=20, + gauss_mask_std=0.0, + rng=None, + ): + super().__init__() + + if rect_masks > 0: + self.spec_cutout = SpecCutout(rect_masks=rect_masks, rect_time=rect_time, rect_freq=rect_freq, rng=rng,) + # self.spec_cutout.to(self._device) + else: + self.spec_cutout = lambda x: x + + if freq_masks + time_masks > 0: + self.spec_augment = SpecAugment( + freq_masks=freq_masks, time_masks=time_masks, freq_width=freq_width, time_width=time_width, + max_time_masks=max_time_masks, gauss_mask_std=gauss_mask_std, rng=rng, + ) + else: + self.spec_augment = lambda x, l: x + + @typecheck() + def forward(self, input_spec, length): + augmented_spec = self.spec_cutout(input_spec) + augmented_spec = self.spec_augment(augmented_spec, length) + return augmented_spec + + +class CropOrPadSpectrogramAugmentation(NeuralModule): + """ + Pad or Crop the incoming Spectrogram to a certain shape. + Args: + audio_length (int): the final number of timesteps that is required. + The signal will be either padded or cropped temporally to this + size. + """ + + def __init__(self, audio_length): + super(CropOrPadSpectrogramAugmentation, self).__init__() + self.audio_length = audio_length + + @typecheck() + @torch.no_grad() + def forward(self, input_signal, length): + image = input_signal + num_images = image.shape[0] + + audio_length = self.audio_length + image_len = image.shape[-1] + + # Crop long signal + if image_len > audio_length: # randomly slice + cutout_images = [] + offset = torch.randint(low=0, high=image_len - audio_length + 1, size=[num_images]) + + for idx, offset in enumerate(offset): + cutout_images.append(image[idx : idx + 1, :, offset : offset + audio_length]) + + image = torch.cat(cutout_images, dim=0) + del cutout_images + + else: # symmetrically pad short signal with zeros + pad_left = (audio_length - image_len) // 2 + pad_right = (audio_length - image_len) // 2 + + if (audio_length - image_len) % 2 == 1: + pad_right += 1 + + image = torch.nn.functional.pad(image, [pad_left, pad_right], mode="constant", value=0) + + # Replace dynamic length sequences with static number of timesteps + length = (length * 0) + audio_length + + return image, length + + @property + def input_types(self): + """Returns definitions of module output ports. + """ + return { + "input_signal": NeuralType(('B', 'D', 'T'), SpectrogramType()), + "length": NeuralType(tuple('B'), LengthsType()), + } + + @property + def output_types(self): + """Returns definitions of module output ports. + """ + return { + "processed_signal": NeuralType(('B', 'D', 'T'), SpectrogramType()), + "processed_length": NeuralType(tuple('B'), LengthsType()), + } + + def save_to(self, save_path: str): + pass + + @classmethod + def restore_from(cls, restore_path: str): + pass + + +@dataclass +class AudioToMelSpectrogramPreprocessorConfig: + _target_: str = "nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor" + sample_rate: int = 16000 + window_size: float = 0.02 + window_stride: float = 0.01 + n_window_size: Optional[int] = None + n_window_stride: Optional[int] = None + window: str = "hann" + normalize: str = "per_feature" + n_fft: Optional[int] = None + preemph: float = 0.97 + features: int = 64 + lowfreq: int = 0 + highfreq: Optional[int] = None + log: bool = True + log_zero_guard_type: str = "add" + log_zero_guard_value: float = 2 ** -24 + dither: float = 1e-5 + dither_train_only: bool = False + pad_to: int = 16 + frame_splicing: int = 1 + stft_exact_pad: bool = False + stft_conv: bool = False + pad_value: int = 0 + mag_power: float = 2.0 + normalize_time_domain: bool = False + + +@dataclass +class SpectrogramAugmentationConfig: + _target_: str = "nemo.collections.asr.modules.SpectrogramAugmentation" + freq_masks: int = 0 + time_masks: Any = 0 + max_time_masks: int = 20 + freq_width: int = 0 + time_width: Optional[Any] = 0 + rect_masks: int = 0 + rect_time: int = 0 + rect_freq: int = 0 + gauss_mask_std: float = 0.0 diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/modules/beam_search_decoder.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/modules/beam_search_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..e36612b856200f6c03ed10988fb9405e74cceebd --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/modules/beam_search_decoder.py @@ -0,0 +1,104 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from nemo.core.classes import NeuralModule, typecheck +from nemo.core.neural_types import LengthsType, LogprobsType, NeuralType, PredictionsType + + +class BeamSearchDecoderWithLM(NeuralModule): + """Neural Module that does CTC beam search with a N-gram language model. + It takes a batch of log_probabilities. Note the bigger the batch, the + better as processing is parallelized. Outputs a list of size batch_size. + Each element in the list is a list of size beam_search, and each element + in that list is a tuple of (final_log_prob, hyp_string). + Args: + vocab (list): List of characters that can be output by the ASR model. For English, this is the 28 character set + {a-z '}. The CTC blank symbol is automatically added. + beam_width (int): Size of beams to keep and expand upon. Larger beams result in more accurate but slower + predictions + alpha (float): The amount of importance to place on the N-gram language model. Larger alpha means more + importance on the LM and less importance on the acoustic model. + beta (float): A penalty term given to longer word sequences. Larger beta will result in shorter sequences. + lm_path (str): Path to N-gram language model + num_cpus (int): Number of CPUs to use + cutoff_prob (float): Cutoff probability in vocabulary pruning, default 1.0, no pruning + cutoff_top_n (int): Cutoff number in pruning, only top cutoff_top_n characters with highest probs in + vocabulary will be used in beam search, default 40. + input_tensor (bool): Set to True if you intend to pass PyTorch Tensors, set to False if you intend to pass + NumPy arrays. + """ + + @property + def input_types(self): + """Returns definitions of module input ports. + """ + return { + "log_probs": NeuralType(('B', 'T', 'D'), LogprobsType()), + "log_probs_length": NeuralType(tuple('B'), LengthsType()), + } + + @property + def output_types(self): + """Returns definitions of module output ports. + """ + return {"predictions": NeuralType(('B', 'T'), PredictionsType())} + + def __init__( + self, vocab, beam_width, alpha, beta, lm_path, num_cpus, cutoff_prob=1.0, cutoff_top_n=40, input_tensor=False + ): + + try: + from ctc_decoders import Scorer, ctc_beam_search_decoder_batch + except ModuleNotFoundError: + raise ModuleNotFoundError( + "BeamSearchDecoderWithLM requires the " + "installation of ctc_decoders " + "from scripts/install_ctc_decoders.sh" + ) + + super().__init__() + + if lm_path is not None: + self.scorer = Scorer(alpha, beta, model_path=lm_path, vocabulary=vocab) + else: + self.scorer = None + self.beam_search_func = ctc_beam_search_decoder_batch + self.vocab = vocab + self.beam_width = beam_width + self.num_cpus = num_cpus + self.cutoff_prob = cutoff_prob + self.cutoff_top_n = cutoff_top_n + self.input_tensor = input_tensor + + @typecheck() + @torch.no_grad() + def forward(self, log_probs, log_probs_length): + probs_list = log_probs + if self.input_tensor: + probs = torch.exp(log_probs) + probs_list = [] + for i, prob in enumerate(probs): + probs_list.append(prob[: log_probs_length[i], :]) + res = self.beam_search_func( + probs_list, + self.vocab, + beam_size=self.beam_width, + num_processes=self.num_cpus, + ext_scoring_func=self.scorer, + cutoff_prob=self.cutoff_prob, + cutoff_top_n=self.cutoff_top_n, + ) + return res diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/modules/conformer_encoder.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/modules/conformer_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..b2e338f06c1e7c8c9fc2d44b00a8a581527c9ab2 --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/modules/conformer_encoder.py @@ -0,0 +1,238 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from collections import OrderedDict + +import torch +import torch.nn as nn + +from nemo.collections.asr.parts.conformer_modules import ConformerEncoderBlock +from nemo.collections.asr.parts.multi_head_attention import PositionalEncoding, RelPositionalEncoding +from nemo.collections.asr.parts.subsampling import ConvSubsampling +from nemo.core.classes.common import typecheck +from nemo.core.classes.exportable import Exportable +from nemo.core.classes.module import NeuralModule +from nemo.core.neural_types import AcousticEncodedRepresentation, LengthsType, NeuralType, SpectrogramType + +__all__ = ['ConformerEncoder'] + + +class ConformerEncoder(NeuralModule, Exportable): + """ + The encoder for ASR model of Conformer. + Based on this paper: + 'Conformer: Convolution-augmented Transformer for Speech Recognition' by Anmol Gulati et al. + https://arxiv.org/abs/2005.08100 + + Args: + feat_in (int): the size of feature channels + n_layers (int): number of layers of ConformerBlock + d_model (int): the hidden size of the model + feat_out (int): the size of the output features + Defaults to -1 (means feat_out is d_model) + subsampling (str): the method of subsampling, choices=['vggnet', 'striding'] + subsampling_factor (int): the subsampling factor which should be power of 2 + Defaults to 4. + subsampling_conv_channels (int): the size of the convolutions in the subsampling module + Defaults to 64. + ff_expansion_factor (int): the expansion factor in feed forward layers + Defaults to 4. + self_attention_model (str): type of the attention layer and positional encoding + choices=['rel_pos', 'abs_pos']. + pos_emb_max_len (int): the maximum length of positional embeddings + Defaulst to 5000 + n_heads (int): number of heads in multi-headed attention layers + Defaults to 4. + xscaling (bool): enables scaling the inputs to the multi-headed attention layers by sqrt(d_model) + Defaults to True. + conv_kernel_size (int): the size of the convolutions in the convolutional modules + Defaults to 31. + dropout (float): the dropout rate used in all layers except the attention layers + Defaults to 0.1. + dropout_emb (float): the dropout rate used for the positional embeddings + Defaults to 0.1. + dropout_att (float): the dropout rate used for the attention layer + Defaults to 0.0. + """ + + def _prepare_for_export(self): + Exportable._prepare_for_export(self) + + def input_example(self): + """ + Generates input examples for tracing etc. + Returns: + A tuple of input examples. + """ + input_example = torch.randn(16, self._feat_in, 256).to(next(self.parameters()).device) + return tuple([input_example]) + + @property + def input_types(self): + """Returns definitions of module input ports. + """ + return OrderedDict( + { + "audio_signal": NeuralType(('B', 'D', 'T'), SpectrogramType()), + "length": NeuralType(tuple('B'), LengthsType()), + } + ) + + @property + def output_types(self): + """Returns definitions of module output ports. + """ + return OrderedDict( + { + "outputs": NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation()), + "encoded_lengths": NeuralType(tuple('B'), LengthsType()), + } + ) + + def __init__( + self, + feat_in, + n_layers, + d_model, + feat_out=-1, + subsampling='vggnet', + subsampling_factor=4, + subsampling_conv_channels=64, + ff_expansion_factor=4, + self_attention_model='rel_pos', + pos_emb_max_len=5000, + n_heads=4, + xscaling=True, + untie_biases=False, + conv_kernel_size=31, + dropout=0.1, + dropout_emb=0.1, + dropout_att=0.0, + ): + super().__init__() + + d_ff = d_model * ff_expansion_factor + self.d_model = d_model + self.scale = math.sqrt(self.d_model) + + if xscaling: + self.xscale = math.sqrt(d_model) + else: + self.xscale = None + + if subsampling: + self.pre_encode = ConvSubsampling( + subsampling=subsampling, + subsampling_factor=subsampling_factor, + feat_in=feat_in, + feat_out=d_model, + conv_channels=subsampling_conv_channels, + activation=nn.ReLU(), + ) + self._feat_out = d_model + else: + self._feat_out = d_model + self.pre_encode = nn.Linear(feat_in, d_model) + + if not untie_biases and self_attention_model == "rel_pos": + d_head = d_model // n_heads + pos_bias_u = nn.Parameter(torch.Tensor(n_heads, d_head)) + pos_bias_v = nn.Parameter(torch.Tensor(n_heads, d_head)) + nn.init.zeros_(pos_bias_u) + nn.init.zeros_(pos_bias_v) + else: + pos_bias_u = None + pos_bias_v = None + + if self_attention_model == "rel_pos": + self.pos_enc = RelPositionalEncoding( + d_model=d_model, + dropout_rate=dropout, + max_len=pos_emb_max_len, + xscale=self.xscale, + dropout_emb_rate=dropout_emb, + ) + elif self_attention_model == "abs_pos": + pos_bias_u = None + pos_bias_v = None + self.pos_enc = PositionalEncoding( + d_model=d_model, dropout_rate=dropout, max_len=pos_emb_max_len, reverse=False, xscale=self.xscale + ) + else: + raise ValueError(f"Not valid self_attention_model: '{self_attention_model}'!") + + self.layers = nn.ModuleList() + for i in range(n_layers): + layer = ConformerEncoderBlock( + d_model=d_model, + d_ff=d_ff, + conv_kernel_size=conv_kernel_size, + self_attention_model=self_attention_model, + n_heads=n_heads, + dropout=dropout, + dropout_att=dropout_att, + pos_bias_u=pos_bias_u, + pos_bias_v=pos_bias_v, + ) + self.layers.append(layer) + + if feat_out > 0 and feat_out != self.output_dim: + self.out_proj = nn.Linear(self.feat_out, feat_out) + self._feat_out = feat_out + else: + self.out_proj = None + self._feat_out = d_model + + @typecheck() + def forward(self, audio_signal, length): + audio_signal = torch.transpose(audio_signal, 1, 2) + + if isinstance(self.pre_encode, ConvSubsampling): + audio_signal, length = self.pre_encode(audio_signal, length) + else: + audio_signal = self.embed(audio_signal) + + audio_signal, pos_emb = self.pos_enc(audio_signal) + # audio_signal, pos_emb = self.pos_enc2(audio_signal) + bs, xmax, idim = audio_signal.size() + + # Create the self-attention and padding masks + pad_mask = self.make_pad_mask(length, max_time=xmax, device=audio_signal.device) + xx_mask = pad_mask.unsqueeze(1).repeat([1, xmax, 1]) + xx_mask = ~(xx_mask & xx_mask.transpose(1, 2)) + pad_mask = ~pad_mask + + for lth, layer in enumerate(self.layers): + audio_signal = layer(x=audio_signal, att_mask=xx_mask, pos_emb=pos_emb, pad_mask=pad_mask) + + if self.out_proj is not None: + audio_signal = self.out_proj(audio_signal) + + audio_signal = torch.transpose(audio_signal, 1, 2) + return audio_signal, length + + @staticmethod + def make_pad_mask(seq_lens, max_time, device=None): + """Make masking for padding.""" + bs = seq_lens.size(0) + seq_range = torch.arange(0, max_time, dtype=torch.int32) + seq_range_expand = seq_range.unsqueeze(0).expand(bs, max_time) + seq_lens = seq_lens.type(seq_range_expand.dtype).to(seq_range_expand.device) + seq_length_expand = seq_lens.unsqueeze(-1) + mask = seq_range_expand < seq_length_expand + + if device: + mask = mask.to(device) + return mask diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/modules/conv_asr.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/modules/conv_asr.py new file mode 100644 index 0000000000000000000000000000000000000000..4bbd974385b7b2e59d818e36eb2709a8d8d6eb0a --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/modules/conv_asr.py @@ -0,0 +1,555 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from collections import OrderedDict +from dataclasses import dataclass, field +from typing import List, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from omegaconf import MISSING, ListConfig, OmegaConf + +from nemo.collections.asr.parts.convolution_layers import ConvNormAct, create_pad_mask, ProjUpsampling +from nemo.collections.asr.parts.jasper import ( + JasperBlock, + MaskedConv1d, + StatsPoolLayer, + init_weights, + jasper_activations, +) +from nemo.core.classes.common import typecheck +from nemo.core.classes.exportable import Exportable +from nemo.core.classes.module import NeuralModule +from nemo.core.neural_types import ( + AcousticEncodedRepresentation, + LengthsType, + LogitsType, + LogprobsType, + NeuralType, + SpectrogramType, + VoidType) +from nemo.utils import logging + +__all__ = ['ConvASRDecoder', 'ConvASREncoder', 'ConvASRDecoderClassification'] + + +class ConvASREncoder(NeuralModule, Exportable): + """ + Convolutional encoder for ASR models. With this class you can implement JasperNet and QuartzNet models. + Based on these papers: + https://arxiv.org/pdf/1904.03288.pdf + https://arxiv.org/pdf/1910.10261.pdf + """ + + def _prepare_for_export(self): + m_count = 0 + for m in self.modules(): + if isinstance(m, MaskedConv1d): + m.use_mask = False + m_count += 1 + logging.warning(f"Turned off {m_count} masked convolutions") + + def input_example(self): + """ + Generates input examples for tracing etc. + Returns: + A tuple of input examples. + """ + input_example = torch.randn(16, self._feat_in, 256).to(next(self.parameters()).device) + return tuple([input_example]) + + @property + def disabled_deployment_input_names(self): + """Implement this method to return a set of input names disabled for export""" + return set(["length"]) + + @property + def disabled_deployment_output_names(self): + """Implement this method to return a set of output names disabled for export""" + return set(["encoded_lengths"]) + + def save_to(self, save_path: str): + pass + + @classmethod + def restore_from(cls, restore_path: str): + pass + + @property + def input_types(self): + """Returns definitions of module input ports. + """ + return OrderedDict( + { + "audio_signal": NeuralType(('B', 'D', 'T'), SpectrogramType()), + "length": NeuralType(tuple('B'), LengthsType()), + } + ) + + @property + def output_types(self): + """Returns definitions of module output ports. + """ + return OrderedDict( + { + "outputs": NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation()), + "encoded_lengths": NeuralType(tuple('B'), LengthsType()), + "extra": NeuralType(elements_type=VoidType()), + } + ) + + def __init__( + self, + jasper, + activation: str, + feat_in: int, + normalization_mode: str = "batch", + residual_mode: str = "add", + norm_groups: int = -1, + conv_mask: bool = True, + frame_splicing: int = 1, + init_mode: Optional[str] = 'xavier_uniform', + ): + super().__init__() + if isinstance(jasper, ListConfig): + jasper = OmegaConf.to_container(jasper) + + activation = jasper_activations[activation]() + feat_in = feat_in * frame_splicing + + self._feat_in = feat_in + + residual_panes = [] + encoder_layers = [] + self.dense_residual = False + for lcfg in jasper: + dense_res = [] + if lcfg.get('residual_dense', False): + residual_panes.append(feat_in) + dense_res = residual_panes + self.dense_residual = True + groups = lcfg.get('groups', 1) + separable = lcfg.get('separable', False) + heads = lcfg.get('heads', -1) + residual_mode = lcfg.get('residual_mode', residual_mode) + se = lcfg.get('se', False) + se_reduction_ratio = lcfg.get('se_reduction_ratio', 8) + se_context_window = lcfg.get('se_context_size', -1) + se_interpolation_mode = lcfg.get('se_interpolation_mode', 'nearest') + kernel_size_factor = lcfg.get('kernel_size_factor', 1.0) + stride_last = lcfg.get('stride_last', False) + encoder_layers.append( + JasperBlock( + feat_in, + lcfg['filters'], + repeat=lcfg['repeat'], + kernel_size=lcfg['kernel'], + stride=lcfg['stride'], + dilation=lcfg['dilation'], + dropout=lcfg['dropout'], + residual=lcfg['residual'], + groups=groups, + separable=separable, + heads=heads, + residual_mode=residual_mode, + normalization=normalization_mode, + norm_groups=norm_groups, + activation=activation, + residual_panes=dense_res, + conv_mask=conv_mask, + se=se, + se_reduction_ratio=se_reduction_ratio, + se_context_window=se_context_window, + se_interpolation_mode=se_interpolation_mode, + kernel_size_factor=kernel_size_factor, + stride_last=stride_last, + ) + ) + feat_in = lcfg['filters'] + + self._feat_out = feat_in + + self.encoder = torch.nn.Sequential(*encoder_layers) + self.apply(lambda x: init_weights(x, mode=init_mode)) + + @typecheck() + def forward(self, audio_signal, length=None): + s_input, length = self.encoder(([audio_signal], length)) + if length is None: + return s_input[-1] + + # Ensure that shape mismatch does not occur due to padding + # Due to padding and subsequent downsampling, it may be possible that + # max sequence length computed does not match the actual max sequence length + output = s_input[-1] + max_output_len = length.max() + if output.shape[2] != max_output_len: + output = output.narrow(dim=2, start=0, length=max_output_len).contiguous() + + return output, length, None + + +class ConvASRDecoder(NeuralModule, Exportable): + """Simple ASR Decoder for use with CTC-based models such as JasperNet and QuartzNet + + Based on these papers: + https://arxiv.org/pdf/1904.03288.pdf + https://arxiv.org/pdf/1910.10261.pdf + https://arxiv.org/pdf/2005.04290.pdf + """ + + def save_to(self, save_path: str): + pass + + @classmethod + def restore_from(cls, restore_path: str): + pass + + @property + def input_types(self): + return OrderedDict({"encoder_output": NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation())}) + + @property + def output_types(self): + return OrderedDict({"logprobs": NeuralType(('B', 'T', 'D'), LogprobsType())}) + + def __init__(self, feat_in, num_classes, *, proj_upsampling=None, conv_layers=None, projector=None, use_conv_mask=True, use_tf_pad=True, ln_eps=1e-5, + blank_pos: str = 'after_vocab_last', init_mode="xavier_uniform", vocabulary=None): + super().__init__() + + self._feat_in = feat_in + + assert vocabulary is not None + self.__vocabulary = vocabulary + if num_classes == 0: + num_classes = len(vocabulary) + elif num_classes != len(vocabulary): + raise ValueError( + f"If vocabulary is specified, it's length should be equal to the num_classes. Instead got: num_classes={num_classes} and len(vocabulary)={len(vocabulary)}" + ) + + if blank_pos == 'after_vocab_last': + # Add 1 for blank char + self._num_classes = num_classes + 1 + self.blank_idx = self._num_classes - 1 + elif blank_pos == 'vocab_first': + self._num_classes = num_classes + self.blank_idx = 0 + else: + assert blank_pos == 'vocab_last' + self._num_classes = num_classes + self.blank_idx = self._num_classes - 1 + + prev_out_channels = self._feat_in + + if proj_upsampling is not None: + self.proj_upsampling = ProjUpsampling(in_channels=prev_out_channels, + use_tf_pad=use_tf_pad, + ln_eps=ln_eps, + **proj_upsampling) + prev_out_channels = proj_upsampling.filters + else: + self.proj_upsampling = None + + self.use_conv_mask = use_conv_mask + + if conv_layers is None: + conv_layers = [] + conv_layers_list = [] + for conv_cfg_i in conv_layers: + layer = ConvNormAct(in_channels=prev_out_channels, + conv_type='1d', + use_tf_pad=use_tf_pad, + ln_eps=ln_eps, + **conv_cfg_i) + prev_out_channels = conv_cfg_i.filters + conv_layers_list.append(layer) + if conv_layers_list: + self.conv_layers = nn.ModuleList(conv_layers_list) + self.conv_layers.apply(lambda x: init_weights(x, mode=init_mode)) + else: + self.conv_layers = [] + + if projector is not None: + from nemo.collections.asr.parts.spec2vec import Projector + projector.input_dim = prev_out_channels + self.projector = Projector(projector) + prev_out_channels = self.projector.output_dim + else: + self.projector = None + + self.decoder_layers = torch.nn.Sequential( + torch.nn.Conv1d(prev_out_channels, self._num_classes, kernel_size=1, bias=True) + ) + self.decoder_layers.apply(lambda x: init_weights(x, mode=init_mode)) + + def forward(self, encoder_output, lens=None, log_prob=True): + + if self.proj_upsampling is not None: + encoder_output, lens = self.proj_upsampling(encoder_output, lens) + + if len(self.conv_layers) > 0 and self.use_conv_mask: + pad_mask = create_pad_mask(lens, max_len=encoder_output.size(2)) + else: + pad_mask = None + for conv_layer in self.conv_layers: + encoder_output, lens, pad_mask = conv_layer(encoder_output, lens, pad_mask=pad_mask) + + if self.projector is not None: + encoder_output = encoder_output.transpose(1, 2) + encoder_output = self.projector(encoder_output, length=lens) + encoder_output = encoder_output.transpose(1, 2) + + logits = self.decoder_layers(encoder_output).transpose(1, 2) + if log_prob: + return torch.nn.functional.log_softmax(logits, dim=-1), lens + else: + return logits, lens + + def input_example(self): + """ + Generates input examples for tracing etc. + Returns: + A tuple of input examples. + """ + bs = 8 + seq = 64 + input_example = torch.randn(bs, self._feat_in, seq).to(next(self.parameters()).device) + return tuple([input_example]) + + def _prepare_for_export(self): + m_count = 0 + for m in self.modules(): + if type(m).__name__ == "MaskedConv1d": + m.use_mask = False + m_count += 1 + if m_count > 0: + logging.warning(f"Turned off {m_count} masked convolutions") + Exportable._prepare_for_export(self) + + @property + def vocabulary(self): + return self.__vocabulary + + @property + def num_classes_with_blank(self): + return self._num_classes + + +class ConvASRDecoderClassification(NeuralModule, Exportable): + """Simple ASR Decoder for use with classification models such as JasperNet and QuartzNet + + Based on these papers: + https://arxiv.org/pdf/2005.04290.pdf + """ + + def input_example(self): + """ + Generates input examples for tracing etc. + Returns: + A tuple of input examples. + """ + input_example = torch.randn(16, self._feat_in, 128).to(next(self.parameters()).device) + return tuple([input_example]) + + @property + def input_types(self): + return OrderedDict({"encoder_output": NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation())}) + + @property + def output_types(self): + return OrderedDict({"logits": NeuralType(('B', 'D'), LogitsType())}) + + def __init__( + self, + feat_in: int, + num_classes: int, + init_mode: Optional[str] = "xavier_uniform", + return_logits: bool = True, + pooling_type='avg', + ): + super().__init__() + + self._feat_in = feat_in + self._return_logits = return_logits + self._num_classes = num_classes + + if pooling_type == 'avg': + self.pooling = torch.nn.AdaptiveAvgPool1d(1) + elif pooling_type == 'max': + self.pooling = torch.nn.AdaptiveMaxPool1d(1) + else: + raise ValueError('Pooling type chosen is not valid. Must be either `avg` or `max`') + + self.decoder_layers = torch.nn.Sequential(torch.nn.Linear(self._feat_in, self._num_classes, bias=True)) + self.apply(lambda x: init_weights(x, mode=init_mode)) + + @typecheck() + def forward(self, encoder_output): + batch, in_channels, timesteps = encoder_output.size() + + encoder_output = self.pooling(encoder_output).view(batch, in_channels) # [B, C] + logits = self.decoder_layers(encoder_output) # [B, num_classes] + + if self._return_logits: + return logits + + return torch.nn.functional.softmax(logits, dim=-1) + + @property + def num_classes(self): + return self._num_classes + + +class SpeakerDecoder(NeuralModule, Exportable): + """ + Speaker Decoder creates the final neural layers that maps from the outputs + of Jasper Encoder to the embedding layer followed by speaker based softmax loss. + Args: + feat_in (int): Number of channels being input to this module + num_classes (int): Number of unique speakers in dataset + emb_sizes (list) : shapes of intermediate embedding layers (we consider speaker embbeddings from 1st of this layers) + Defaults to [1024,1024] + pool_mode (str) : Pooling stratergy type. options are 'gram','xvector','superVector'. + Defaults to 'xvector' + init_mode (str): Describes how neural network parameters are + initialized. Options are ['xavier_uniform', 'xavier_normal', + 'kaiming_uniform','kaiming_normal']. + Defaults to "xavier_uniform". + """ + + def input_example(self): + """ + Generates input examples for tracing etc. + Returns: + A tuple of input examples. + """ + input_example = torch.randn(16, self.input_feat_in, 256).to(next(self.parameters()).device) + return tuple([input_example]) + + @property + def input_types(self): + return OrderedDict({"encoder_output": NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation())}) + + @property + def output_types(self): + return OrderedDict( + { + "logits": NeuralType(('B', 'D'), LogitsType()), + "embs": NeuralType(('B', 'D'), AcousticEncodedRepresentation()), + } + ) + + def __init__( + self, feat_in, num_classes, emb_sizes=None, pool_mode='xvector', angular=False, init_mode="xavier_uniform", + ): + super().__init__() + self.angular = angular + self.emb_id = 2 + if self.angular: + bias = False + else: + bias = True + + if type(emb_sizes) is str: + emb_sizes = emb_sizes.split(',') + elif type(emb_sizes) is int: + emb_sizes = [emb_sizes] + else: + emb_sizes = [512, 512] + + self.input_feat_in = feat_in + self._num_classes = num_classes + self._pooling = StatsPoolLayer(feat_in=feat_in, pool_mode=pool_mode) + self._feat_in = self._pooling.feat_in + + shapes = [self._feat_in] + for size in emb_sizes: + shapes.append(int(size)) + + emb_layers = [] + for shape_in, shape_out in zip(shapes[:-1], shapes[1:]): + layer = self.affineLayer(shape_in, shape_out, learn_mean=False) + emb_layers.append(layer) + + self.emb_layers = nn.ModuleList(emb_layers) + + self.final = nn.Linear(shapes[-1], self._num_classes, bias=bias) + + self.apply(lambda x: init_weights(x, mode=init_mode)) + + def affineLayer(self, inp_shape, out_shape, learn_mean=True): + layer = nn.Sequential( + nn.Linear(inp_shape, out_shape), + nn.BatchNorm1d(out_shape, affine=learn_mean, track_running_stats=True), + nn.ReLU(), + ) + + return layer + + @typecheck() + def forward(self, encoder_output): + pool = self._pooling(encoder_output) + embs = [] + + for layer in self.emb_layers: + pool, emb = layer(pool), layer[: self.emb_id](pool) + embs.append(emb) + + if self.angular: + for W in self.final.parameters(): + W = F.normalize(W, p=2, dim=1) + pool = F.normalize(pool, p=2, dim=1) + + out = self.final(pool) + + return out, embs[-1] + + +@dataclass +class JasperEncoderConfig: + filters: int = MISSING + repeat: int = MISSING + kernel: List[int] = MISSING + stride: List[int] = MISSING + dilation: List[int] = MISSING + dropout: float = MISSING + residual: bool = MISSING + + # Optional arguments + groups: int = 1 + separable: bool = False + heads: int = -1 + residual_mode: str = "add" + residual_dense: bool = False + se: bool = False + se_reduction_ratio: int = 8 + se_context_size: int = -1 + se_interpolation_mode: str = 'nearest' + kernel_size_factor: float = 1.0 + stride_last: bool = False + + +@dataclass +class ConvASREncoderConfig: + _target_: str = 'nemo.collections.asr.modules.ConvASREncoder' + jasper: Optional[JasperEncoderConfig] = field(default_factory=list) + activation: str = MISSING + feat_in: int = MISSING + normalization_mode: str = "batch" + residual_mode: str = "add" + norm_groups: int = -1 + conv_mask: bool = True + frame_splicing: int = 1 + init_mode: str = "xavier_uniform" diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/modules/conv_transformer_encoder.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/modules/conv_transformer_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..e848b3cf4abb1e0f0621252229bdb8eb45471157 --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/modules/conv_transformer_encoder.py @@ -0,0 +1,166 @@ +from typing import List + +import omegaconf +import torch +import torch.nn as nn + +from nemo.collections.asr.models.configs import convtt_models_config as cfg +from nemo.collections.asr.parts.convolution_layers import ConvNormAct, create_pad_mask, Conv +from nemo.collections.common.parts.mem_transformer import RelTransformerBlock +import nemo.collections.common.parts.mem_transformer as mem_transformer + + +class ConvTransformerEncoder(nn.Module): + """ + Args: + feat_in (int): the size of feature channels + feat_out (int): the size of the output features + Defaults to -1 (means feat_out is d_model) + """ + + def __init__(self, feat_in, use_conv_mask, conv2d_block: cfg.Conv2dBlock, conv_transformer_blocks: List[cfg.ConvTransformerBlock], + use_tf_pad: bool, ln_eps: float = 1e-5, + init_mode='xavier_uniform', bias_init_mode='zero'): + super().__init__() + + freeze_config(conv2d_block, conv_transformer_blocks) + + self.use_conv_mask = use_conv_mask + + if conv2d_block: + prev_out_channels = 1 + self.conv2d_block = nn.ModuleList() + for conv2d_cfg_i in conv2d_block.layers: + layer = ConvNormAct(in_channels=prev_out_channels, + conv_type='2d', + use_tf_pad=use_tf_pad, + ln_eps=ln_eps, + **conv2d_cfg_i) + prev_out_channels = conv2d_cfg_i.filters + self.conv2d_block.append(layer) + prev_out_channels = conv2d_block.output_dim + else: + self.conv2d_block = None + prev_out_channels = feat_in + + self.block_modules = nn.ModuleList() + for block_cfg in conv_transformer_blocks: + for conv_cfg_i in block_cfg.conv_layers: + layer = ConvNormAct(in_channels=prev_out_channels, + conv_type='1d', + use_tf_pad=use_tf_pad, + ln_eps=ln_eps, + **conv_cfg_i) + prev_out_channels = conv_cfg_i.filters + self.block_modules.append(layer) + + if block_cfg.transformer_block is not None: + block = RelTransformerBlock(**block_cfg.transformer_block, ln_eps=ln_eps) + self.block_modules.append(block) + + self.apply(lambda x: init_weights(x, mode=init_mode, bias_mode=bias_init_mode)) + + def forward(self, audio_signal, length, match_output_len=True): + # [B, F/D, T] + output = audio_signal + + if self.use_conv_mask: + pad_mask = create_pad_mask(length, max_len=output.size(2)) + else: + pad_mask = None + + if self.conv2d_block is not None: + # [B, F, T] => [B, T, F] =>[B, C, T, F] + output = torch.transpose(output, 1, 2).unsqueeze(1) + for module in self.conv2d_block: + output, length, pad_mask = module(output, length, pad_mask=pad_mask) + b, c, t, f = output.size() + # [B, C, T, F] => [B, F, C, T] => [B, FxC/D, T] + output = output.permute(0, 3, 1, 2).reshape(b, f * c, t) + + adaptive_ffn_loss = None + for module in self.block_modules: + if isinstance(module, ConvNormAct): + output, length, pad_mask = module(output, length, pad_mask=pad_mask) + else: + assert isinstance(module, RelTransformerBlock) + # [B, D, T] => [T, B, D] + output = output.permute(2, 0, 1) + output, _, extra = module(output, lens=length) + # [T, B, D] => [B, D, T] + output = output.permute(1, 2, 0) + + if extra[0] is not None: + if adaptive_ffn_loss is None: + adaptive_ffn_loss = extra[0] + else: + adaptive_ffn_loss = adaptive_ffn_loss + extra[0] + + if match_output_len: + # Ensure that shape mismatch does not occur due to padding + # Due to padding and subsequent downsampling, it may be possible that + # max sequence length computed does not match the actual max sequence length + max_output_len = length.max() + if output.shape[2] != max_output_len: + output = output.narrow(dim=2, start=0, length=max_output_len).contiguous() + + return output, length, (adaptive_ffn_loss,) + + +def init_weights(m, mode='xavier_uniform', bias_mode='zero', embedding_mode='xavier_uniform'): + if mode == 'xavier_uniform': + init_ = nn.init.xavier_uniform_ + elif mode == 'xavier_normal': + init_ = nn.init.xavier_normal_ + else: + assert mode == 'torch_default' + init_ = lambda x: x + + if bias_mode == 'zero': + init_bias_ = nn.init.zeros_ + else: + assert bias_mode == 'torch_default' + init_bias_ = lambda x: x + + from nemo.collections.asr.modules import TransformerTDecoder, RNNTJoint + from nemo.collections.common.parts.normalization import LayerVarNorm + from nemo.collections.asr.modules.text_embed import TextEmbed + from nemo.collections.asr.modules import TextConvTransformerEncoder + if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Linear)): + init_(m.weight) + if m.bias is not None: + init_bias_(m.bias) + elif isinstance(m, nn.Embedding): + assert embedding_mode == mode + init_(m.weight) + if mode != 'torch_default': + if m.padding_idx is not None: + with torch.no_grad(): + m.weight[m.padding_idx].fill_(0) + elif isinstance(m, RelTransformerBlock): + if m.r_w_bias is not None: + init_(m.r_w_bias) + if m.r_r_bias is not None: + init_(m.r_r_bias) + elif isinstance(m, mem_transformer.GateNet): + init_(m.weight) + # GateNet itself will handle bias init + elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.LayerNorm, LayerVarNorm, nn.GroupNorm)): + pass # use default init + elif isinstance(m, (nn.Dropout, nn.ReLU, nn.Sequential, nn.ModuleList)): + pass # ignore modules do not need init + elif isinstance(m, (ConvNormAct, Conv, mem_transformer.RelPartialLearnableMultiHeadAttn, + mem_transformer.PositionwiseFF, mem_transformer.RelPartialLearnableDecoderLayer, + mem_transformer.PositionalEmbedding, ConvTransformerEncoder, + TransformerTDecoder, RNNTJoint, mem_transformer.AdaptiveFFN, TextEmbed, TextConvTransformerEncoder)): + pass # ignore wrapper modules + elif hasattr(m, 'HH_skip_init'): + pass + else: + raise ValueError('initializing unknown module type {}'.format(type(m))) + + +def freeze_config(*configs): + for conf in configs: + if isinstance(conf, omegaconf.Container): + omegaconf.OmegaConf.set_struct(conf, True) diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/modules/lstm_decoder.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/modules/lstm_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..5552638c8b8b8ba6f1e58e2bd3e9a0239f7e5987 --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/modules/lstm_decoder.py @@ -0,0 +1,98 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import OrderedDict + +import torch +import torch.nn as nn + +from nemo.core.classes.common import typecheck +from nemo.core.classes.exportable import Exportable +from nemo.core.classes.module import NeuralModule +from nemo.core.neural_types import AcousticEncodedRepresentation, LogprobsType, NeuralType + +__all__ = ['LSTMDecoder'] + + +class LSTMDecoder(NeuralModule, Exportable): + """ + Simple LSTM Decoder for ASR models + Args: + feat_in (int): size of the input features + num_classes (int): the size of the vocabulary + lstm_hidden_size (int): hidden size of the LSTM layers + vocabulary (vocab): The vocabulary + bidirectional (bool): default is False. Whether LSTMs are bidirectional or not + num_layers (int): default is 1. Number of LSTM layers stacked + """ + + @property + def input_types(self): + return OrderedDict({"encoder_output": NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation())}) + + @property + def output_types(self): + return OrderedDict({"logprobs": NeuralType(('B', 'T', 'D'), LogprobsType())}) + + def __init__(self, feat_in, num_classes, lstm_hidden_size, vocabulary=None, bidirectional=False, num_layers=1): + super().__init__() + + if vocabulary is not None: + if num_classes != len(vocabulary): + raise ValueError( + f"If vocabulary is specified, it's length should be equal to the num_classes. " + f"Instead got: num_classes={num_classes} and len(vocabulary)={len(vocabulary)}" + ) + self.__vocabulary = vocabulary + self._feat_in = feat_in + # Add 1 for blank char + self._num_classes = num_classes + 1 + + self.lstm_layer = nn.LSTM( + input_size=feat_in, + hidden_size=lstm_hidden_size, + num_layers=num_layers, + batch_first=True, + bidirectional=bidirectional, + ) + self.linear_layer = torch.nn.Linear(in_features=lstm_hidden_size, out_features=self._num_classes) + + @typecheck() + def forward(self, encoder_output): + output = encoder_output.transpose(1, 2) + output, _ = self.lstm_layer(output) + output = self.linear_layer(output) + return torch.nn.functional.log_softmax(output, dim=-1) + + def input_example(self): + """ + Generates input examples for tracing etc. + Returns: + A tuple of input examples. + """ + bs = 8 + seq = 64 + input_example = torch.randn(bs, self._feat_in, seq).to(next(self.parameters()).device) + return tuple([input_example]) + + def _prepare_for_export(self): + Exportable._prepare_for_export(self) + + @property + def vocabulary(self): + return self.__vocabulary + + @property + def num_classes_with_blank(self): + return self._num_classes diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/modules/rnnt.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/modules/rnnt.py new file mode 100644 index 0000000000000000000000000000000000000000..2c446b1caa132efd21b72f152e29d9778137da96 --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/modules/rnnt.py @@ -0,0 +1,950 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright 2017 Johns Hopkins University (Shinji Watanabe) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch + +from nemo.collections.asr.modules import rnnt_abstract +from nemo.collections.asr.modules.conv_transformer_encoder import init_weights +from nemo.collections.asr.parts import rnnt_utils +from nemo.collections.common.parts import rnn +from nemo.core.classes import typecheck +from nemo.core.neural_types import ( + AcousticEncodedRepresentation, + ElementType, + EmbeddedTextType, + LabelsType, + LengthsType, + LogprobsType, + LossType, + NeuralType, +) +from nemo.utils import logging + + +class RNNTDecoder(rnnt_abstract.AbstractRNNTDecoder): + """A Recurrent Neural Network Transducer Decoder / Prediction Network (RNN-T Prediction Network). + An RNN-T Decoder/Prediction network, comprised of a stateful LSTM model. + + Args: + prednet: A dict-like object which contains the following key-value pairs. + pred_hidden: int specifying the hidden dimension of the prediction net. + pred_rnn_layers: int specifying the number of rnn layers. + + Optionally, it may also contain the following: + forget_gate_bias: float, set by default to 1.0, which constructs a forget gate + initialized to 1.0. + Reference: + [An Empirical Exploration of Recurrent Network Architectures](http://proceedings.mlr.press/v37/jozefowicz15.pdf) + t_max: int value, set to None by default. If an int is specified, performs Chrono Initialization + of the LSTM network, based on the maximum number of timesteps `t_max` expected during the course + of training. + Reference: + [Can recurrent neural networks warp time?](https://openreview.net/forum?id=SJcKhk-Ab) + dropout: float, set to 0.0 by default. Optional dropout applied at the end of the final LSTM RNN layer. + + vocab_size: int, specifying the vocabulary size of the embedding layer of the Prediction network, + excluding the RNNT blank token. + + normalization_mode: Can be either None, 'batch' or 'layer'. By default, is set to None. + Defines the type of normalization applied to the RNN layer. + + random_state_sampling: bool, set to False by default. When set, provides normal-distribution + sampled state tensors instead of zero tensors during training. + Reference: + [Recognizing long-form speech using streaming end-to-end models](https://arxiv.org/abs/1910.11455) + + blank_as_pad: bool, set to True by default. When set, will add a token to the Embedding layer of this + prediction network, and will treat this token as a pad token. In essence, the RNNT pad token will + be treated as a pad token, and the embedding layer will return a zero tensor for this token. + + It is set by default as it enables various batch optimizations required for batched beam search. + Therefore, it is not recommended to disable this flag. + """ + + @property + def input_types(self): + """Returns definitions of module input ports. + """ + return { + "targets": NeuralType(('B', 'T'), LabelsType()), + "target_length": NeuralType(tuple('B'), LengthsType()), + "states": NeuralType(('D', 'B', 'D'), ElementType(), optional=True), + } + + @property + def output_types(self): + """Returns definitions of module output ports. + """ + return { + "outputs": NeuralType(('B', 'D', 'T'), EmbeddedTextType()), + "encoded_lengths": NeuralType(tuple('B'), LengthsType()), + } + + def __init__( + self, + prednet: Dict[str, Any], + vocab_size: int, + normalization_mode: Optional[str] = None, + random_state_sampling: bool = False, + blank_as_pad: bool = True, + blank_pos: str = 'after_vocab_last', + ): + # Required arguments + self.pred_hidden = prednet['pred_hidden'] + self.pred_rnn_layers = prednet["pred_rnn_layers"] + assert blank_pos == 'after_vocab_last' + self.blank_idx = vocab_size + + # Initialize the model (blank token increases vocab size by 1) + super().__init__(vocab_size=vocab_size, blank_idx=self.blank_idx, blank_as_pad=blank_as_pad) + + # Optional arguments + forget_gate_bias = prednet.get('forget_gate_bias', 1.0) + t_max = prednet.get('t_max', None) + dropout = prednet.get('dropout', 0.0) + self.random_state_sampling = random_state_sampling + + self.prediction = self._predict( + vocab_size=vocab_size, # add 1 for blank symbol + pred_n_hidden=self.pred_hidden, + pred_rnn_layers=self.pred_rnn_layers, + forget_gate_bias=forget_gate_bias, + t_max=t_max, + norm=normalization_mode, + dropout=dropout, + ) + + @typecheck() + def forward(self, targets, target_length, states=None): + # y: (B, U) + y = rnn.label_collate(targets) + + # state maintenance is unnecessary during training forward call + # to get state, use .predict() method. + g, _ = self.predict(y, state=states, add_sos=True) # (B, U, D) + g = g.transpose(1, 2) # (B, D, U) + + return g, target_length + + def predict( + self, + y: Optional[torch.Tensor] = None, + state: Optional[List[torch.Tensor]] = None, + add_sos: bool = True, + batch_size: Optional[int] = None, + ) -> (torch.Tensor, List[torch.Tensor]): + """ + Stateful prediction of scores and state for a (possibly null) tokenset. + This method takes various cases into consideration : + - No token, no state - used for priming the RNN + - No token, state provided - used for blank token scoring + - Given token, states - used for scores + new states + + Here: + B - batch size + U - label length + H - Hidden dimension size of RNN + L - Number of RNN layers + + Args: + y: Optional torch tensor of shape [B, U] of dtype long which will be passed to the Embedding. + If None, creates a zero tensor of shape [B, 1, H] which mimics output of pad-token on Embedding. + + state: An optional list of states for the RNN. Eg: For LSTM, it is the state list length is 2. + Each state must be a tensor of shape [L, B, H]. + If None, and during training mode and `random_state_sampling` is set, will sample a + normal distribution tensor of the above shape. Otherwise, None will be passed to the RNN. + + add_sos: bool flag, whether a zero vector describing a "start of signal" token should be + prepended to the above "y" tensor. When set, output size is (B, U + 1, H). + + batch_size: An optional int, specifying the batch size of the `y` tensor. + Can be infered if `y` and `state` is None. But if both are None, then batch_size cannot be None. + + Returns: + A tuple (g, hid) such that - + + If add_sos is False: + g: (B, U, H) + hid: (h, c) where h is the final sequence hidden state and c is the final cell state: + h (tensor), shape (L, B, H) + c (tensor), shape (L, B, H) + + If add_sos is True: + g: (B, U + 1, H) + hid: (h, c) where h is the final sequence hidden state and c is the final cell state: + h (tensor), shape (L, B, H) + c (tensor), shape (L, B, H) + + """ + # Get device and dtype of current module + _p = next(self.parameters()) + device = _p.device + dtype = _p.dtype + + # If y is not None, it is of shape [B, U] with dtype long. + if y is not None: + if y.device != device: + y = y.to(device) + + # (B, U) -> (B, U, H) + y = self.prediction["embed"](y) + else: + # Y is not provided, assume zero tensor with shape [B, 1, H] is required + # Emulates output of embedding of pad token. + if batch_size is None: + B = 1 if state is None else state[0].size(1) + else: + B = batch_size + + y = torch.zeros((B, 1, self.pred_hidden), device=device, dtype=dtype) + + # Prepend blank "start of sequence" symbol (zero tensor) + if add_sos: + B, U, H = y.shape + start = torch.zeros((B, 1, H), device=y.device, dtype=y.dtype) + y = torch.cat([start, y], dim=1).contiguous() # (B, U + 1, H) + else: + start = None # makes del call later easier + + # If in training mode, and random_state_sampling is set, + # initialize state to random normal distribution tensor. + if state is None: + if self.random_state_sampling and self.training: + state = self.initialize_state(y) + + # Forward step through RNN + y = y.transpose(0, 1) # (U + 1, B, H) + g, hid = self.prediction["dec_rnn"](y, state) + g = g.transpose(0, 1) # (B, U + 1, H) + + del y, start, state + return g, hid + + def _predict(self, vocab_size, pred_n_hidden, pred_rnn_layers, forget_gate_bias, t_max, norm, dropout): + """ + Prepare the trainable parameters of the Prediction Network. + + Args: + vocab_size: Vocab size (excluding the blank token). + pred_n_hidden: Hidden size of the RNNs. + pred_rnn_layers: Number of RNN layers. + forget_gate_bias: Whether to perform unit forget gate bias. + t_max: Whether to perform Chrono LSTM init. + norm: Type of normalization to perform in RNN. + dropout: Whether to apply dropout to RNN. + """ + if self.blank_as_pad: + embed = torch.nn.Embedding(vocab_size + 1, pred_n_hidden, padding_idx=self.blank_idx) + else: + embed = torch.nn.Embedding(vocab_size, pred_n_hidden) + + layers = torch.nn.ModuleDict( + { + "embed": embed, + "dec_rnn": rnn.rnn( + input_size=pred_n_hidden, + hidden_size=pred_n_hidden, + num_layers=pred_rnn_layers, + norm=norm, + forget_gate_bias=forget_gate_bias, + t_max=t_max, + dropout=dropout, + ), + } + ) + return layers + + def initialize_state(self, y: torch.Tensor) -> List[torch.Tensor]: + """ + Initialize the state of the RNN layers, with same dtype and device as input `y`. + + Args: + y: A torch.Tensor whose device the generated states will be placed on. + + Returns: + List of torch.Tensor, each of shape [L, B, H], where + L = Number of RNN layers + B = Batch size + H = Hidden size of RNN. + """ + batch = y.size(0) + if self.random_state_sampling and self.training: + state = [ + torch.randn(self.pred_rnn_layers, batch, self.pred_hidden, dtype=y.dtype, device=y.device), + torch.randn(self.pred_rnn_layers, batch, self.pred_hidden, dtype=y.dtype, device=y.device), + ] + + else: + state = [ + torch.zeros(self.pred_rnn_layers, batch, self.pred_hidden, dtype=y.dtype, device=y.device), + torch.zeros(self.pred_rnn_layers, batch, self.pred_hidden, dtype=y.dtype, device=y.device), + ] + return state + + def score_hypothesis( + self, hypothesis: rnnt_utils.Hypothesis, cache: Dict[Tuple[int], Any] + ) -> (torch.Tensor, List[torch.Tensor], torch.Tensor): + """ + Similar to the predict() method, instead this method scores a Hypothesis during beam search. + Hypothesis is a dataclass representing one hypothesis in a Beam Search. + + Args: + hypothesis: Refer to rnnt_utils.Hypothesis. + cache: Dict which contains a cache to avoid duplicate computations. + + Returns: + Returns a tuple (y, states, lm_token) such that: + y is a torch.Tensor of shape [1, 1, H] representing the score of the last token in the Hypothesis. + state is a list of RNN states, each of shape [L, 1, H]. + lm_token is the final integer token of the hypothesis. + """ + if hypothesis.dec_state is not None: + device = hypothesis.dec_state[0].device + else: + _p = next(self.parameters()) + device = _p.device + + # parse "blank" tokens in hypothesis + if len(hypothesis.y_sequence) > 0 and hypothesis.y_sequence[-1] == self.blank_idx: + blank_state = True + else: + blank_state = False + + # Convert last token of hypothesis to torch.Tensor + target = torch.full([1, 1], fill_value=hypothesis.y_sequence[-1], device=device, dtype=torch.long) + lm_token = target[:, -1] # [1] + + # Convert current hypothesis into a tuple to preserve in cache + sequence = tuple(hypothesis.y_sequence) + + if sequence in cache: + y, new_state = cache[sequence] + else: + # Obtain score for target token and new states + if blank_state: + y, new_state = self.predict(None, state=None, add_sos=False, batch_size=1) # [1, 1, H] + + else: + y, new_state = self.predict( + target, state=hypothesis.dec_state, add_sos=False, batch_size=1 + ) # [1, 1, H] + + y = y[:, -1:, :] # Extract just last state : [1, 1, H] + cache[sequence] = (y, new_state) + + return y, new_state, lm_token + + def batch_score_hypothesis( + self, hypotheses: List[rnnt_utils.Hypothesis], cache: Dict[Tuple[int], Any], batch_states: List[torch.Tensor] + ) -> (torch.Tensor, List[torch.Tensor], torch.Tensor): + """ + Used for batched beam search algorithms. Similar to score_hypothesis method. + + Args: + hypothesis: List of Hypotheses. Refer to rnnt_utils.Hypothesis. + cache: Dict which contains a cache to avoid duplicate computations. + batch_states: List of torch.Tensor which represent the states of the RNN for this batch. + Each state is of shape [L, B, H] + + Returns: + Returns a tuple (b_y, b_states, lm_tokens) such that: + b_y is a torch.Tensor of shape [B, 1, H] representing the scores of the last tokens in the Hypotheses. + b_state is a list of list of RNN states, each of shape [L, B, H]. + Represented as B x List[states]. + lm_token is a list of the final integer tokens of the hypotheses in the batch. + """ + final_batch = len(hypotheses) + + if final_batch == 0: + raise ValueError("No hypotheses was provided for the batch!") + + _p = next(self.parameters()) + device = _p.device + dtype = _p.dtype + + tokens = [] + process = [] + done = [None for _ in range(final_batch)] + + # For each hypothesis, cache the last token of the sequence and the current states + for i, hyp in enumerate(hypotheses): + sequence = tuple(hyp.y_sequence) + + if sequence in cache: + done[i] = cache[sequence] + else: + tokens.append(hyp.y_sequence[-1]) + process.append((sequence, hyp.dec_state)) + + if process: + batch = len(process) + + # convert list of tokens to torch.Tensor, then reshape. + tokens = torch.tensor(tokens, device=device, dtype=torch.long).view(batch, -1) + dec_states = self.initialize_state(tokens.to(dtype=dtype)) # [L, B, H] + dec_states = self.batch_initialize_states(dec_states, [d_state for seq, d_state in process]) + + y, dec_states = self.predict( + tokens, state=dec_states, add_sos=False, batch_size=batch + ) # [B, 1, H], List([L, 1, H]) + + # Update done states and cache shared by entire batch. + j = 0 + for i in range(final_batch): + if done[i] is None: + # Select sample's state from the batch state list + new_state = self.batch_select_state(dec_states, j) + + # Cache [1, H] scores of the current y_j, and its corresponding state + done[i] = (y[j], new_state) + cache[process[j][0]] = (y[j], new_state) + + j += 1 + + # Set the incoming batch states with the new states obtained from `done`. + batch_states = self.batch_initialize_states(batch_states, [d_state for y_j, d_state in done]) + + # Create batch of all output scores + # List[1, 1, H] -> [B, 1, H] + batch_y = torch.stack([y_j for y_j, d_state in done]) + + # Extract the last tokens from all hypotheses and convert to a tensor + lm_tokens = torch.tensor([h.y_sequence[-1] for h in hypotheses], device=device, dtype=torch.long).view( + final_batch + ) + + return batch_y, batch_states, lm_tokens + + def batch_initialize_states(self, batch_states: List[torch.Tensor], decoder_states: List[List[torch.Tensor]]): + """ + Create batch of decoder states. + + Args: + batch_states (list): batch of decoder states + ([L x (B, H)], [L x (B, H)]) + + decoder_states (list of list): list of decoder states + [B x ([L x (1, H)], [L x (1, H)])] + + Returns: + batch_states (tuple): batch of decoder states + ([L x (B, H)], [L x (B, H)]) + """ + # LSTM has 2 states + for layer in range(self.pred_rnn_layers): + for state_id in range(len(batch_states)): + batch_states[state_id][layer] = torch.stack([s[state_id][layer] for s in decoder_states]) + + return batch_states + + def batch_select_state(self, batch_states: List[torch.Tensor], idx: int) -> List[List[torch.Tensor]]: + """Get decoder state from batch of states, for given id. + + Args: + batch_states (list): batch of decoder states + ([L x (B, H)], [L x (B, H)]) + + idx (int): index to extract state from batch of states + + Returns: + (tuple): decoder states for given id + ([L x (1, H)], [L x (1, H)]) + """ + state_list = [] + for state_id in range(len(batch_states)): + states = [batch_states[state_id][layer][idx] for layer in range(self.pred_rnn_layers)] + state_list.append(states) + + return state_list + + +class RNNTJoint(rnnt_abstract.AbstractRNNTJoint): + """A Recurrent Neural Network Transducer Joint Network (RNN-T Joint Network). + An RNN-T Joint network, comprised of a feedforward model. + + Args: + jointnet: A dict-like object which contains the following key-value pairs. + encoder_hidden: int specifying the hidden dimension of the encoder net. + pred_hidden: int specifying the hidden dimension of the prediction net. + joint_hidden: int specifying the hidden dimension of the joint net + activation: Activation function used in the joint step. Can be one of + ['relu', 'tanh', 'sigmoid']. + + Optionally, it may also contain the following: + dropout: float, set to 0.0 by default. Optional dropout applied at the end of the joint net. + + num_classes: int, specifying the vocabulary size that the joint network must predict, + excluding the RNNT blank token. + + vocabulary: Optional list of strings/tokens that comprise the vocabulary of the joint network. + Unused and kept only for easy access for character based encoding RNNT models. + + log_softmax: Optional bool, set to None by default. If set as None, will compute the log_softmax() + based on the value provided. + + preserve_memory: Optional bool, set to False by default. If the model crashes due to the memory + intensive joint step, one might try this flag to empty the tensor cache in pytorch. + + Warning: This will make the forward-backward pass much slower than normal. + It also might not fix the OOM if the GPU simply does not have enough memory to compute the joint. + + experimental_fuse_loss_wer: Optional bool, set to False by default. + NOTE: This is an experimental feature that attempts to trade of compute time for memory preservation. + There may be undetermined effects to convergence behaviour. + + Fuses the joint forward, loss forward and + wer forward steps. In doing so, it trades of speed for memory conservation by creating sub-batches + of the provided batch of inputs, and performs Joint forward, loss forward and wer forward (optional), + all on sub-batches, then collates results to be exactly equal to results from the entire batch. + + When this flag is set, prior to calling forward, the fields `loss` and `wer` (either one) *must* + be set using the `RNNTJoint.set_loss()` or `RNNTJoint.set_wer()` methods. + + Further, when this flag is set, the following argument `fused_batch_size` *must* be provided + as a non negative integer. This value refers to the size of the sub-batch. + + When the flag is set, the input and output signature of `forward()` of this method changes. + Input - in addition to `encoder_outputs` (mandatory argument), the following arguments can be provided. + - decoder_outputs (optional). Required if loss computation is required. + - encoder_lengths (required) + - transcripts (optional). Required for wer calculation. + - transcript_lengths (optional). Required for wer calculation. + - compute_wer (bool, default false). Whether to compute WER or not for the fused batch. + + Output - instead of the usual `joint` log prob tensor, the following results can be returned. + - loss (optional). Returned if decoder_outputs, transcripts and transript_lengths are not None. + - wer_numerator + wer_denominator (optional). Returned if transcripts, transcripts_lengths are provided + and compute_wer is set. + + fused_batch_size: Optional int, required if `fuse_loss_wer` flag is set. Determines the size of the + sub-batches. Should be any value below the actual batch size per GPU. + """ + + @property + def input_types(self): + """Returns definitions of module input ports. + """ + return { + "encoder_outputs": NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation()), + "decoder_outputs": NeuralType(('B', 'D', 'T'), EmbeddedTextType()), + "encoder_lengths": NeuralType(tuple('B'), LengthsType(), optional=True), + "transcripts": NeuralType(('B', 'T'), LabelsType(), optional=True), + "transcript_lengths": NeuralType(tuple('B'), LengthsType(), optional=True), + "compute_wer": NeuralType(optional=True), + } + + @property + def output_types(self): + """Returns definitions of module output ports. + """ + if not self._fuse_loss_wer: + return { + "outputs": NeuralType(('B', 'T', 'T', 'D'), LogprobsType()), + } + + else: + return { + "loss": NeuralType(elements_type=LossType(), optional=True), + "wer": NeuralType(elements_type=ElementType(), optional=True), + "wer_numer": NeuralType(elements_type=ElementType(), optional=True), + "wer_denom": NeuralType(elements_type=ElementType(), optional=True), + } + + def __init__( + self, + jointnet: Dict[str, Any], + num_classes: int, + vocabulary: Optional[List] = None, + log_softmax: Optional[bool] = None, + preserve_memory: bool = False, + experimental_fuse_loss_wer: bool = False, + fused_batch_size: Optional[int] = None, + blank_pos: str = 'after_vocab_last', + init_mode='torch_default', bias_init_mode='torch_default' + ): + super().__init__() + + self.vocabulary = vocabulary + + self._vocab_size = num_classes + if blank_pos == 'after_vocab_last': + self._num_classes = num_classes + 1 # add 1 for blank symbol + else: + assert blank_pos in ['vocab_first', 'vocab_last'] + self._num_classes = num_classes + + self._fuse_loss_wer = experimental_fuse_loss_wer + self._fused_batch_size = fused_batch_size + + if experimental_fuse_loss_wer and (fused_batch_size is None): + raise ValueError("If `fuse_loss_wer` is set, then `fused_batch_size` cannot be None!") + + if experimental_fuse_loss_wer: + logging.warning( + "\nFused joint step is an experimental technique. Please be aware that it " + "may have unintended side effects!\n" + ) + + self._loss = None + self._wer = None + + # Log softmax should be applied explicitly only for CPU + self.log_softmax = log_softmax + self.preserve_memory = preserve_memory + + if preserve_memory: + logging.warning( + "`preserve_memory` was set for the Joint Model. Please be aware this will severely impact " + "the forward-backward step time. It also might not solve OOM issues if the GPU simply " + "does not have enough memory to compute the joint." + ) + + # Required arguments + self.encoder_hidden = jointnet['encoder_hidden'] + self.pred_hidden = jointnet['pred_hidden'] + self.joint_hidden = jointnet['joint_hidden'] + self.activation = jointnet['activation'] + + # Optional arguments + dropout = jointnet.get('dropout', 0.0) + single_bias = jointnet.get('single_bias', False) + + self.pred, self.enc, self.joint_net = self._joint_net( + num_classes=self._num_classes, # add 1 for blank symbol + pred_n_hidden=self.pred_hidden, + enc_n_hidden=self.encoder_hidden, + joint_n_hidden=self.joint_hidden, + activation=self.activation, + dropout=dropout, + single_bias=single_bias + ) + + self.apply(lambda x: init_weights(x, mode=init_mode, bias_mode=bias_init_mode)) + + @typecheck() + def forward( + self, + encoder_outputs: torch.Tensor, + decoder_outputs: Optional[torch.Tensor], + encoder_lengths: Optional[torch.Tensor] = None, + transcripts: Optional[torch.Tensor] = None, + transcript_lengths: Optional[torch.Tensor] = None, + compute_wer: bool = False, + decode_results: Optional[dict] = None, + log_prediction: bool = True + ) -> Union[torch.Tensor, List[Optional[torch.Tensor]]]: + # encoder = (B, D, T) + # decoder = (B, D, U) if passed, else None + encoder_outputs = encoder_outputs.transpose(1, 2) # (B, T, D) + + if decoder_outputs is not None: + decoder_outputs = decoder_outputs.transpose(1, 2) # (B, U, D) + + if not self._fuse_loss_wer: + if decoder_outputs is None: + raise ValueError( + "decoder_outputs passed is None, and `fuse_loss_wer` is not set. " + "decoder_outputs can only be None for fused step!" + ) + + out = self.joint(encoder_outputs, decoder_outputs) # [B, T, U, V + 1] + return out + + else: + # At least the loss module must be supplied during fused joint + if self._loss is None or self._wer is None: + raise ValueError("`fuse_loss_wer` flag is set, but `loss` and `wer` modules were not provided! ") + + # If fused joint step is required, fused batch size is required as well + if self._fused_batch_size is None: + raise ValueError("If `experimental_fuse_loss_wer` is set, then `fused_batch_size` cannot be None!") + + # When using fused joint step, both encoder and transcript lengths must be provided + if (encoder_lengths is None) or (transcript_lengths is None): + raise ValueError( + "`experimental_fuse_loss_wer` is set, therefore encoder and target lengths " + "must be provided as well!" + ) + + losses = [] + wer_numer_list = [] + wer_denom_list = [] + batch_size = int(encoder_outputs.size(0)) # actual batch size + + # Iterate over batch using fused_batch_size steps + for batch_idx in range(0, batch_size, self._fused_batch_size): + begin = batch_idx + end = min(begin + self._fused_batch_size, batch_size) + + # Extract the sub batch inputs + # sub_enc = encoder_outputs[begin:end, ...] + # sub_transcripts = transcripts[begin:end, ...] + sub_enc = encoder_outputs.narrow(dim=0, start=begin, length=end - begin) + sub_transcripts = transcripts.narrow(dim=0, start=begin, length=end - begin) + + sub_enc_lens = encoder_lengths[begin:end] + sub_transcript_lens = transcript_lengths[begin:end] + + # Sub transcripts does not need the full padding of the entire batch + # Therefore reduce the decoder time steps to match + max_sub_enc_length = sub_enc_lens.max() + max_sub_transcript_length = sub_transcript_lens.max() + + if decoder_outputs is not None: + # Reduce encoder length to preserve computation + # Encoder: [sub-batch, T, D] -> [sub-batch, T', D]; T' < T + if sub_enc.shape[1] != max_sub_enc_length: + sub_enc = sub_enc.narrow(dim=1, start=0, length=max_sub_enc_length) + + # sub_dec = decoder_outputs[begin:end, ...] # [sub-batch, U, D] + sub_dec = decoder_outputs.narrow(dim=0, start=begin, length=end - begin) # [sub-batch, U, D] + + # Reduce decoder length to preserve computation + # Decoder: [sub-batch, U, D] -> [sub-batch, U', D]; U' < U + if sub_dec.shape[1] != max_sub_transcript_length + 1: + sub_dec = sub_dec.narrow(dim=1, start=0, length=max_sub_transcript_length + 1) + + # Perform joint => [sub-batch, T', U', V + 1] + sub_joint = self.joint(sub_enc, sub_dec) + + del sub_dec + + # Reduce transcript length to correct alignment + # Transcript: [sub-batch, L] -> [sub-batch, L']; L' <= L + if sub_transcripts.shape[1] != max_sub_transcript_length: + sub_transcripts = sub_transcripts.narrow(dim=1, start=0, length=max_sub_transcript_length) + + # Compute sub batch loss + # preserve loss reduction type + loss_reduction = self.loss.reduction + + # override loss reduction to sum + self.loss.reduction = None + + # compute and preserve loss + loss_batch = self.loss( + log_probs=sub_joint, + targets=sub_transcripts, + input_lengths=sub_enc_lens, + target_lengths=sub_transcript_lens, + ) + losses.append(loss_batch) + + # reset loss reduction type + self.loss.reduction = loss_reduction + + else: + losses = None + + # Compute WER for sub batch + if compute_wer: + sub_enc = sub_enc.transpose(1, 2) # [B, T, D] -> [B, D, T] + sub_enc = sub_enc.detach() + sub_transcripts = sub_transcripts.detach() + + original_log_prediction = self.wer.log_prediction + if batch_idx == 0: + self.wer.log_prediction = True + else: + self.wer.log_prediction = False + + # Compute the wer (with logging for just 1st sub-batch) + self.wer.update(sub_enc, sub_enc_lens, sub_transcripts, sub_transcript_lens, + decode_results=decode_results, log_prediction=log_prediction) + wer, wer_num, wer_denom = self.wer.compute() + + wer_numer_list.append(wer_num) + wer_denom_list.append(wer_denom) + + # Reset logging default + self.wer.log_prediction = original_log_prediction + + else: + wer = None + + del sub_enc, sub_transcripts, sub_enc_lens, sub_transcript_lens + + # Collect sub batch loss results + if losses is not None: + losses = torch.cat(losses, 0) + losses = losses.mean() # global batch size average + else: + losses = None + + # Collect sub batch wer results + if compute_wer: + wer_num = torch.tensor(wer_numer_list, dtype=torch.long) + wer_denom = torch.tensor(wer_denom_list, dtype=torch.long) + + wer_num = wer_num.sum() # global sum of correct words/chars + wer_denom = wer_denom.sum() # global sum of all words/chars + else: + wer_num = None + wer_denom = None + + return losses, wer, wer_num, wer_denom + + def joint(self, f: torch.Tensor, g: torch.Tensor, apply_softmax=False) -> torch.Tensor: + """ + Compute the joint step of the network. + + Here, + B = Batch size + T = Acoustic model timesteps + U = Target sequence length + H1, H2 = Hidden dimensions of the Encoder / Decoder respectively + H = Hidden dimension of the Joint hidden step. + V = Vocabulary size of the Decoder (excluding the RNNT blank token). + + NOTE: + The implementation of this model is slightly modified from the original paper. + The original paper proposes the following steps : + (enc, dec) -> Expand + Concat + Sum [B, T, U, H1+H2] -> Forward through joint hidden [B, T, U, H] -- *1 + *1 -> Forward through joint final [B, T, U, V + 1]. + + We instead split the joint hidden into joint_hidden_enc and joint_hidden_dec and act as follows: + enc -> Forward through joint_hidden_enc -> Expand [B, T, 1, H] -- *1 + dec -> Forward through joint_hidden_dec -> Expand [B, 1, U, H] -- *2 + (*1, *2) -> Sum [B, T, U, H] -> Forward through joint final [B, T, U, V + 1]. + + Args: + f: Output of the Encoder model. A torch.Tensor of shape [B, T, H1] + g: Output of the Decoder model. A torch.Tensor of shape [B, U, H2] + + Returns: + Logits / log softmaxed tensor of shape (B, T, U, V + 1). + """ + # f = [B, T, H1] + f = self.enc(f) + f = f.unsqueeze(dim=2) # (B, T, 1, H) + + # g = [B, U, H2] + g = self.pred(g) + g = g.unsqueeze(dim=1) # (B, 1, U, H) + + inp = f + g # [B, T, U, H] + + del f, g + + res = self.joint_net(inp) # [B, T, U, V + 1] + + del inp + + if self.preserve_memory: + torch.cuda.empty_cache() + + # If log_softmax is automatic + if self.log_softmax is None: + if not res.is_cuda: # Use log softmax only if on CPU + assert apply_softmax + res = res.log_softmax(dim=-1) + else: + if self.log_softmax: + assert apply_softmax + res = res.log_softmax(dim=-1) + + return res + + def _joint_net(self, num_classes, pred_n_hidden, enc_n_hidden, joint_n_hidden, activation, dropout, single_bias): + """ + Prepare the trainable modules of the Joint Network + + Args: + num_classes: Number of output classes (vocab size) excluding the RNNT blank token. + pred_n_hidden: Hidden size of the prediction network. + enc_n_hidden: Hidden size of the encoder network. + joint_n_hidden: Hidden size of the joint network. + activation: Activation of the joint. Can be one of [relu, tanh, sigmoid] + dropout: Dropout value to apply to joint. + """ + pred = torch.nn.Linear(pred_n_hidden, joint_n_hidden) + enc = torch.nn.Linear(enc_n_hidden, joint_n_hidden, bias=False if single_bias else True) + + if activation not in ['relu', 'sigmoid', 'tanh']: + raise ValueError("Unsupported activation for joint step - please pass one of " "[relu, sigmoid, tanh]") + + activation = activation.lower() + + if activation == 'relu': + activation = torch.nn.ReLU(inplace=True) + elif activation == 'sigmoid': + activation = torch.nn.Sigmoid() + elif activation == 'tanh': + activation = torch.nn.Tanh() + + layers = ( + [activation] + + ([torch.nn.Dropout(p=dropout)] if dropout else []) + + [torch.nn.Linear(joint_n_hidden, num_classes)] + ) + return pred, enc, torch.nn.Sequential(*layers) + + @property + def num_classes_with_blank(self): + return self._num_classes + + @property + def loss(self): + return self._loss + + def set_loss(self, loss): + if not self._fuse_loss_wer: + raise ValueError("Attempting to set loss module even though `fuse_loss_wer` is not set!") + + self._loss = loss + + @property + def wer(self): + return self._wer + + def set_wer(self, wer): + if not self._fuse_loss_wer: + raise ValueError("Attempting to set WER module even though `fuse_loss_wer` is not set!") + + self._wer = wer + + @property + def fuse_loss_wer(self): + return self._fuse_loss_wer + + def set_fuse_loss_wer(self, fuse_loss_wer): + self._fuse_loss_wer = fuse_loss_wer + + if self._fuse_loss_wer is False: + self._loss = None + self._wer = None + + @property + def fused_batch_size(self): + return self._fuse_loss_wer + + def set_fused_batch_size(self, fused_batch_size): + self._fused_batch_size = fused_batch_size diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/modules/rnnt_abstract.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/modules/rnnt_abstract.py new file mode 100644 index 0000000000000000000000000000000000000000..07213df1c2b8dbb4d646d8104498d2baa90b0377 --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/modules/rnnt_abstract.py @@ -0,0 +1,236 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from abc import ABC, abstractmethod +from typing import Any, Dict, List, Optional, Tuple + +import torch + +from nemo.collections.asr.parts.rnnt_utils import Hypothesis +from nemo.core import NeuralModule + + +class AbstractRNNTJoint(NeuralModule, ABC): + """ + An abstract RNNT Joint framework, which can possibly integrate with GreedyRNNTInfer and BeamRNNTInfer classes. + Represents the abstract RNNT Joint network, which accepts the acoustic model and prediction network + embeddings in order to compute the joint of the two prior to decoding the output sequence. + """ + + @abstractmethod + def joint(self, f: torch.Tensor, g: torch.Tensor, apply_softmax: bool=False) -> torch.Tensor: + """ + Compute the joint step of the network. + + Here, + B = Batch size + T = Acoustic model timesteps + U = Target sequence length + H1, H2 = Hidden dimensions of the Encoder / Decoder respectively + H = Hidden dimension of the Joint hidden step. + V = Vocabulary size of the Decoder (excluding the RNNT blank token). + + NOTE: + The implementation of this model is slightly modified from the original paper. + The original paper proposes the following steps : + (enc, dec) -> Expand + Concat + Sum [B, T, U, H1+H2] -> Forward through joint hidden [B, T, U, H] -- *1 + *1 -> Forward through joint final [B, T, U, V + 1]. + + We instead split the joint hidden into joint_hidden_enc and joint_hidden_dec and act as follows: + enc -> Forward through joint_hidden_enc -> Expand [B, T, 1, H] -- *1 + dec -> Forward through joint_hidden_dec -> Expand [B, 1, U, H] -- *2 + (*1, *2) -> Sum [B, T, U, H] -> Forward through joint final [B, T, U, V + 1]. + + Args: + f: Output of the Encoder model. A torch.Tensor of shape [B, T, H1] + g: Output of the Decoder model. A torch.Tensor of shape [B, U, H2] + + Returns: + Logits / log softmaxed tensor of shape (B, T, U, V + 1). + """ + raise NotImplementedError() + + @property + def num_classes_with_blank(self): + raise NotImplementedError() + + +class AbstractRNNTDecoder(NeuralModule, ABC): + """ + An abstract RNNT Decoder framework, which can possibly integrate with GreedyRNNTInfer and BeamRNNTInfer classes. + Represents the abstract RNNT Prediction/Decoder stateful network, which performs autoregressive decoding + in order to construct the output sequence. + + Args: + vocab_size: Size of the vocabulary, excluding the RNNT blank token. + blank_idx: Index of the blank token. Can be 0 or size(vocabulary). + blank_as_pad: Bool flag, whether to allocate an additional token in the Embedding layer + of this module in order to treat all RNNT `blank` tokens as pad tokens, thereby letting + the Embedding layer batch tokens more efficiently. + + It is mandatory to use this for certain Beam RNNT Infer methods - such as TSD, ALSD. + It is also more efficient to use greedy batch decoding with this flag. + """ + + def __init__(self, vocab_size, blank_idx, blank_as_pad): + super().__init__() + + self.vocab_size = vocab_size + self.blank_idx = blank_idx # first or last index of vocabulary + self.blank_as_pad = blank_as_pad + + if blank_idx not in [0, vocab_size, vocab_size - 1]: + raise ValueError("`blank_idx` must be either 0 or the final token of the vocabulary") + + @abstractmethod + def predict( + self, + y: Optional[torch.Tensor] = None, + state: Optional[torch.Tensor] = None, + add_sos: bool = False, + batch_size: Optional[int] = None, + ) -> (torch.Tensor, List[torch.Tensor]): + """ + Stateful prediction of scores and state for a (possibly null) tokenset. + This method takes various cases into consideration : + - No token, no state - used for priming the RNN + - No token, state provided - used for blank token scoring + - Given token, states - used for scores + new states + + Here: + B - batch size + U - label length + H - Hidden dimension size of RNN + L - Number of RNN layers + + Args: + y: Optional torch tensor of shape [B, U] of dtype long which will be passed to the Embedding. + If None, creates a zero tensor of shape [B, 1, H] which mimics output of pad-token on Embedding. + + state: An optional list of states for the RNN. Eg: For LSTM, it is the state list length is 2. + Each state must be a tensor of shape [L, B, H]. + If None, and during training mode and `random_state_sampling` is set, will sample a + normal distribution tensor of the above shape. Otherwise, None will be passed to the RNN. + + add_sos: bool flag, whether a zero vector describing a "start of signal" token should be + prepended to the above "y" tensor. When set, output size is (B, U + 1, H). + + batch_size: An optional int, specifying the batch size of the `y` tensor. + Can be infered if `y` and `state` is None. But if both are None, then batch_size cannot be None. + + Returns: + A tuple (g, hid) such that - + + If add_sos is False: + g: (B, U, H) + hid: (h, c) where h is the final sequence hidden state and c is the final cell state: + h (tensor), shape (L, B, H) + c (tensor), shape (L, B, H) + + If add_sos is True: + g: (B, U + 1, H) + hid: (h, c) where h is the final sequence hidden state and c is the final cell state: + h (tensor), shape (L, B, H) + c (tensor), shape (L, B, H) + + """ + raise NotImplementedError() + + @abstractmethod + def initialize_state(self, y: torch.Tensor) -> List[torch.Tensor]: + """ + Initialize the state of the RNN layers, with same dtype and device as input `y`. + + Args: + y: A torch.Tensor whose device the generated states will be placed on. + + Returns: + List of torch.Tensor, each of shape [L, B, H], where + L = Number of RNN layers + B = Batch size + H = Hidden size of RNN. + """ + raise NotImplementedError() + + @abstractmethod + def score_hypothesis( + self, hypothesis: Hypothesis, cache: Dict[Tuple[int], Any] + ) -> (torch.Tensor, List[torch.Tensor], torch.Tensor): + """ + Similar to the predict() method, instead this method scores a Hypothesis during beam search. + Hypothesis is a dataclass representing one hypothesis in a Beam Search. + + Args: + hypothesis: Refer to rnnt_utils.Hypothesis. + cache: Dict which contains a cache to avoid duplicate computations. + + Returns: + Returns a tuple (y, states, lm_token) such that: + y is a torch.Tensor of shape [1, 1, H] representing the score of the last token in the Hypothesis. + state is a list of RNN states, each of shape [L, 1, H]. + lm_token is the final integer token of the hypothesis. + """ + raise NotImplementedError() + + def batch_score_hypothesis( + self, hypotheses: List[Hypothesis], cache: Dict[Tuple[int], Any], batch_states: List[torch.Tensor] + ) -> (torch.Tensor, List[torch.Tensor], torch.Tensor): + """ + Used for batched beam search algorithms. Similar to score_hypothesis method. + + Args: + hypothesis: List of Hypotheses. Refer to rnnt_utils.Hypothesis. + cache: Dict which contains a cache to avoid duplicate computations. + batch_states: List of torch.Tensor which represent the states of the RNN for this batch. + Each state is of shape [L, B, H] + + Returns: + Returns a tuple (b_y, b_states, lm_tokens) such that: + b_y is a torch.Tensor of shape [B, 1, H] representing the scores of the last tokens in the Hypotheses. + b_state is a list of list of RNN states, each of shape [L, B, H]. + Represented as B x List[states]. + lm_token is a list of the final integer tokens of the hypotheses in the batch. + """ + raise NotImplementedError() + + def batch_initialize_states(self, batch_states: List[torch.Tensor], decoder_states: List[List[torch.Tensor]]): + """ + Create batch of decoder states. + + Args: + batch_states (list): batch of decoder states + ([L x (B, H)], [L x (B, H)]) + + decoder_states (list of list): list of decoder states + [B x ([L x (1, H)], [L x (1, H)])] + + Returns: + batch_states (tuple): batch of decoder states + ([L x (B, H)], [L x (B, H)]) + """ + raise NotImplementedError() + + def batch_select_state(self, batch_states: List[torch.Tensor], idx: int) -> List[List[torch.Tensor]]: + """Get decoder state from batch of states, for given id. + + Args: + batch_states (list): batch of decoder states + ([L x (B, H)], [L x (B, H)]) + + idx (int): index to extract state from batch of states + + Returns: + (tuple): decoder states for given id + ([L x (1, H)], [L x (1, H)]) + """ + raise NotImplementedError() diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/modules/text_conv_transformer_encoder.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/modules/text_conv_transformer_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..6b4d9a6c58e6a1e360f9dcabf7713c06a9378a0c --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/modules/text_conv_transformer_encoder.py @@ -0,0 +1,102 @@ +import omegaconf +import torch.nn as nn + +from hmodels.modules.gaussian_upsampling import GaussianUpsampling +from nemo.collections.asr.models.configs import convtt_models_config as cfg +from nemo.collections.asr.models.configs.common_config import TextEmbed as TextEmbedCfg +from nemo.collections.asr.modules.conv_transformer_encoder import init_weights +from nemo.collections.asr.modules.text_embed import TextEmbed +from nemo.collections.asr.parts.convolution_layers import ConvNormAct, create_pad_mask +from nemo.collections.common.parts.mem_transformer import RelTransformerBlock + + +class TextConvTransformerEncoder(nn.Module): + def __init__(self, text_embed: TextEmbedCfg, use_conv_mask, encoder: cfg.ConvTransformerBlock, + use_tf_pad: bool, ln_eps: float = 1e-5, + init_mode='xavier_uniform', bias_init_mode='zero', embedding_init_mode='xavier_uniform', + pad_to=0, pad_value=0.0, upsampling_args=None): + super().__init__() + + freeze_config(text_embed, encoder) + + self.text_embed = TextEmbed(**text_embed, ln_eps=ln_eps, init_mode=init_mode, bias_init_mode=bias_init_mode, + embedding_init_mode=embedding_init_mode) + + self.use_conv_mask = use_conv_mask + self.pad_to = pad_to + self.pad_value = pad_value + + self.block_modules = nn.ModuleList() + for conv_cfg_i in encoder.conv_layers: + layer = ConvNormAct(in_channels=self.text_embed.embed_size, + conv_type='1d', + use_tf_pad=use_tf_pad, + ln_eps=ln_eps, + **conv_cfg_i) + self.block_modules.append(layer) + + if upsampling_args: + upsampling_layer = GaussianUpsampling(upsampling_args.variance) + self.block_modules.append(upsampling_layer) + + if encoder.transformer_block is not None: + block = RelTransformerBlock(**encoder.transformer_block, ln_eps=ln_eps) + self.block_modules.append(block) + + self.apply(lambda x: init_weights(x, mode=init_mode, bias_mode=bias_init_mode)) + + def forward(self, text, length, match_output_len=True, upsampling_lens=None, + upsampling_max_len=None): + output = self.text_embed(text) + + # [B, T, D] => [B, D, T] + output = output.permute(0, 2, 1) + + if self.pad_to > 0: + # pad to multiple of pad_to (to avoid issues caused by downsampling and for efficiency) + pad_amt = output.size(-1) % self.pad_to + if pad_amt != 0: + output = nn.functional.pad(output, (0, self.pad_to - pad_amt), value=self.pad_value) + + if self.use_conv_mask: + pad_mask = create_pad_mask(length, max_len=output.size(2)) + else: + pad_mask = None + + for module in self.block_modules: + if isinstance(module, ConvNormAct): + output, length, pad_mask = module(output, length, pad_mask=pad_mask) + elif isinstance(module, GaussianUpsampling): + assert upsampling_max_len is not None + assert upsampling_lens is not None + upsampling_durations = (upsampling_lens.float() / length.float()).unsqueeze(1).expand(-1, output.shape[2]) + # [B, D, T] => [B, T, D] + output = output.permute(0, 2, 1) + output = module(output, length, upsampling_durations, upsampling_max_len) + # [B, T, D] => [B, D, T] + output = output.permute(0, 2, 1) + length = upsampling_lens + pad_mask = create_pad_mask(length, max_len=output.size(2)) + else: + assert isinstance(module, RelTransformerBlock) + # [B, D, T] => [T, B, D] + output = output.permute(2, 0, 1) + output, _, extra = module(output, lens=length) + # [T, B, D] => [B, D, T] + output = output.permute(1, 2, 0) + + if match_output_len: + # Ensure that shape mismatch does not occur due to padding + # Due to padding and subsequent downsampling, it may be possible that + # max sequence length computed does not match the actual max sequence length + max_output_len = length.max() + if output.shape[2] != max_output_len: + output = output.narrow(dim=2, start=0, length=max_output_len).contiguous() + + return output, length, None + + +def freeze_config(*configs): + for conf in configs: + if isinstance(conf, omegaconf.Container): + omegaconf.OmegaConf.set_struct(conf, True) diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/modules/text_embed.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/modules/text_embed.py new file mode 100644 index 0000000000000000000000000000000000000000..ce61c70a8975464704c270ae1eae388c0f81781d --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/modules/text_embed.py @@ -0,0 +1,69 @@ +from typing import Optional + +import torch + +from nemo.collections.asr.modules.conv_transformer_encoder import init_weights +from nemo.collections.common.parts import rnn + + +class TextEmbed(torch.nn.Module): + + def __init__(self, + embed_num: int, + embed_size: int, + embed_dropout: float, + embed_proj_size: int, + mask_label_prob: float = 0.0, + mask_label_id: Optional[int] = None, + norm_embed: bool = False, + ln_eps: float = 1e-5, + init_mode='xavier_uniform', bias_init_mode='zero', embedding_init_mode='xavier_uniform'): + super().__init__() + + self.embed = torch.nn.Embedding(embed_num, embed_size) + self.embed_size = embed_size + + if embed_proj_size > 0: + self.embed_proj = torch.nn.Linear(embed_size, embed_proj_size) + self.embed_size = embed_proj_size + else: + self.embed_proj = identity + + if norm_embed: + self.embed_proj_norm = torch.nn.LayerNorm(embed_proj_size if embed_proj_size else embed_size, eps=ln_eps) + else: + self.embed_proj_norm = identity + + self.embed_drop = torch.nn.Dropout(embed_dropout) + + self.mask_label_prob = mask_label_prob + self.mask_label_id = mask_label_id + assert 0.0 <= self.mask_label_prob < 1.0 + if self.mask_label_prob > 0: + assert self.mask_label_id is not None and self.mask_label_id >= 0 + assert self.mask_label_id not in [self.sos_idx, self.blank_idx] + + self.apply(lambda x: init_weights(x, mode=init_mode, bias_mode=bias_init_mode, + embedding_mode=embedding_init_mode)) + + def forward(self, texts): + y = rnn.label_collate(texts) + + if self.mask_label_prob > 0 and self.training: + y = random_replace(y, rep_prob=self.mask_label_prob, rep_id=self.mask_label_id) + + h = self.embed(y) + h = self.embed_proj(h) + h = self.embed_proj_norm(h) + h = self.embed_drop(h) + + return h + + +def random_replace(inputs: torch.Tensor, rep_prob, rep_id): + mask = torch.bernoulli(torch.full(inputs.size(), rep_prob, device=inputs.device)).type(inputs.dtype) + return mask * rep_id + (1 - mask) * inputs + + +def identity(x): + return x diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/modules/transformer_t_decoder.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/modules/transformer_t_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..3e7c15409de45f2f3bdd1e88761a06a33f477376 --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/modules/transformer_t_decoder.py @@ -0,0 +1,384 @@ +from typing import Any, Dict, List, Optional, Tuple + +import torch + +from nemo.collections.asr.modules import rnnt_abstract +from nemo.collections.asr.modules.conv_transformer_encoder import init_weights +from nemo.collections.asr.parts import rnnt_utils +from nemo.collections.common.parts import rnn +from nemo.collections.common.parts.mem_transformer import RelTransformerBlock +from nemo.collections.asr.models.configs import convtt_models_config as cfg +from nemo.core.classes import typecheck +from nemo.core.neural_types import ( + ElementType, + EmbeddedTextType, + LabelsType, + LengthsType, + NeuralType, +) + + +class TransformerTDecoder(rnnt_abstract.AbstractRNNTDecoder): + @property + def input_types(self): + """Returns definitions of module input ports. + """ + return { + "targets": NeuralType(('B', 'T'), LabelsType()), + "target_length": NeuralType(tuple('B'), LengthsType()), + "states": NeuralType(('D', 'B', 'D'), ElementType(), optional=True), + } + + @property + def output_types(self): + """Returns definitions of module output ports. + """ + return { + "outputs": NeuralType(('B', 'D', 'T'), EmbeddedTextType()), + "encoded_lengths": NeuralType(tuple('B'), LengthsType()), + } + + def __init__(self, + vocab_size: int, + embed_size: int, + embed_dropout: float, + embed_proj_size: int, + sos_idx: Optional[int], + transformer_block: cfg.RelTransformerBlock, + blank_pos: str = 'vocab_last', + blank_as_pad: bool = True, + mask_label_prob: float = 0.0, + mask_label_id: Optional[int] = None, + norm_embed: bool = False, + norm_embed_proj: bool = False, + ln_eps: float = 1e-5, + init_mode='xavier_uniform', bias_init_mode='zero', embedding_init_mode='xavier_uniform'): + if blank_pos == 'vocab_last': + self.blank_idx = vocab_size - 1 + elif blank_pos == 'vocab_first': + self.blank_idx = 0 + else: + assert blank_pos == 'after_vocab_last' + self.blank_idx = vocab_size + + super().__init__(vocab_size=vocab_size, blank_idx=self.blank_idx, blank_as_pad=blank_as_pad) + + self.transformer_block_cfg = transformer_block + self.sos_idx = sos_idx + self.prepend_sos_label = self.sos_idx is not None + if self.prepend_sos_label: + assert self.sos_idx >= 0 + + if self.blank_as_pad: + # if blank is used as pad, ensure blank is in the input embedding + embed_num = vocab_size + 1 if blank_pos == 'after_vocab_last' else vocab_size + padding_idx = self.blank_idx + else: + embed_num = vocab_size + padding_idx = None + self.embed = torch.nn.Embedding(embed_num, embed_size, padding_idx=padding_idx) + + self.embed_drop = torch.nn.Dropout(embed_dropout) + + if norm_embed: + self.embed_norm = torch.nn.LayerNorm(embed_size, eps=ln_eps) + else: + self.embed_norm = identity + + if embed_proj_size > 0: + self.embed_proj = torch.nn.Linear(embed_size, embed_proj_size) + else: + self.embed_proj = identity + + if norm_embed_proj: + assert embed_proj_size > 0 + self.embed_proj_norm = torch.nn.LayerNorm(embed_proj_size, eps=ln_eps) + else: + self.embed_proj_norm = identity + + self.mask_label_prob = mask_label_prob + self.mask_label_id = mask_label_id + assert 0.0 <= self.mask_label_prob < 1.0 + if self.mask_label_prob > 0: + assert self.mask_label_id is not None and self.mask_label_id >= 0 + assert self.mask_label_id not in [self.sos_idx, self.blank_idx] + + self.transformer_block = RelTransformerBlock(**self.transformer_block_cfg, ln_eps=ln_eps) + + self.apply(lambda x: init_weights(x, mode=init_mode, bias_mode=bias_init_mode, + embedding_mode=embedding_init_mode)) + + @typecheck() + def forward(self, targets, target_length, states=None): + # y: [B, U] + y = rnn.label_collate(targets) + + is_decoding = states is not None + if self.prepend_sos_label and not is_decoding: + y = torch.nn.functional.pad(y, [1, 0], value=self.sos_idx) + # we pad y, not targets, so target_length do not change + # target_length = target_length + 1 + + if self.mask_label_prob > 0 and self.training: + y = random_replace(y, rep_prob=self.mask_label_prob, rep_id=self.mask_label_id) + + # state maintenance is unnecessary during training forward call + # to get state, use .predict() method. + h, _ = self.predict(y, state=states, add_sos=not self.prepend_sos_label) + # [B, U, H] => [B, H, U+1] + h = h.transpose(1, 2) + + return h, target_length + + def predict(self, + y: Optional[torch.Tensor] = None, + state: Optional[List[torch.Tensor]] = None, + add_sos: bool = True, + batch_size: Optional[int] = None) -> (torch.Tensor, List[torch.Tensor]): + # Get device and dtype of current module + _p = next(self.parameters()) + device = _p.device + dtype = _p.dtype + + # If y is not None, it is of shape [B, U] with dtype long. + if y is not None: + if y.device != device: + y = y.to(device) + # (B, U) -> (B, U, H) + h = self.embed(y) + h = self.embed_drop(h) + h = self.embed_norm(h) + else: + assert not self.prepend_sos_label + # Y is not provided, assume state tensor is required + # Emulates output of embedding of pad token. + if batch_size is None: + B = 1 if state is None else state[0][0].size(1) + else: + B = batch_size + + h = torch.zeros((B, 1, self.embed.embedding_dim), device=device, dtype=dtype) + + # Prepend blank "start of sequence" symbol (zero tensor) + if add_sos: + assert not self.prepend_sos_label + B, U, H = h.shape + start = torch.zeros((B, 1, H), device=h.device, dtype=h.dtype) + h = torch.cat([start, h], dim=1).contiguous() # (B, U + 1, H) + else: + start = None # makes del call later easier + + h = self.embed_proj(h) + h = self.embed_proj_norm(h) + + # [B, U+1, H] => [U+1, B, H] + h = h.transpose(0, 1) + h, new_state, _ = self.transformer_block(h, mems=state) + # [U+1, B, H] => [B, U+1, H] + h = h.transpose(0, 1) + + del start, state + return h, new_state + + def initialize_state(self, batch_size, dtype, device) -> List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]: + # assume y has shape [T, B, D] + transformer_block_params = self.transformer_block_cfg + n_layer = transformer_block_params['n_layer'] + n_head = transformer_block_params['n_head'] + d_head = transformer_block_params['d_head'] + use_mq_attn = transformer_block_params.get('use_mq_attn', False) + + if use_mq_attn: + init_state = [(torch.zeros(0, batch_size, d_head, dtype=dtype, device=device), + torch.zeros(0, batch_size, d_head, dtype=dtype, device=device), + torch.tensor(0, dtype=torch.int32, device=device)) for _ in range(n_layer)] + else: + init_state = [(torch.zeros(0, batch_size, n_head, d_head, dtype=dtype, device=device), + torch.zeros(0, batch_size, n_head, d_head, dtype=dtype, device=device), + torch.tensor(0, dtype=torch.int32, device=device)) for _ in range(n_layer)] + return init_state + + def score_hypothesis( + self, hypothesis: rnnt_utils.Hypothesis, cache: Dict[Tuple[int], Any] + ) -> (torch.Tensor, List[torch.Tensor], torch.Tensor): + """ + Similar to the predict() method, instead this method scores a Hypothesis during beam search. + Hypothesis is a dataclass representing one hypothesis in a Beam Search. + + Args: + hypothesis: Refer to rnnt_utils.Hypothesis. + cache: Dict which contains a cache to avoid duplicate computations. + + Returns: + Returns a tuple (y, states, lm_token) such that: + y is a torch.Tensor of shape [1, 1, H] representing the score of the last token in the Hypothesis. + state is a list of RNN states, each of shape [L, 1, H]. + lm_token is the final integer token of the hypothesis. + """ + if hypothesis.dec_state is not None: + device = hypothesis.dec_state[0].device + else: + _p = next(self.parameters()) + device = _p.device + + # parse "blank" tokens in hypothesis + if len(hypothesis.y_sequence) > 0 and hypothesis.y_sequence[-1] == self.blank_idx: + blank_state = True + else: + blank_state = False + + # Convert last token of hypothesis to torch.Tensor + target = torch.full([1, 1], fill_value=hypothesis.y_sequence[-1], device=device, dtype=torch.long) + lm_token = target[:, -1] # [1] + + # Convert current hypothesis into a tuple to preserve in cache + sequence = tuple(hypothesis.y_sequence) + + if sequence in cache: + y, new_state = cache[sequence] + else: + # Obtain score for target token and new states + if blank_state: + y, new_state = self.predict(None, state=None, add_sos=False, batch_size=1) # [1, 1, H] + + else: + y, new_state = self.predict( + target, state=hypothesis.dec_state, add_sos=False, batch_size=1 + ) # [1, 1, H] + + y = y[:, -1:, :] # Extract just last state : [1, 1, H] + cache[sequence] = (y, new_state) + + return y, new_state, lm_token + + def batch_score_hypothesis( + self, hypotheses: List[rnnt_utils.Hypothesis], cache: Dict[Tuple[int], Any], batch_states: List[torch.Tensor] + ) -> (torch.Tensor, List[torch.Tensor], torch.Tensor): + """ + Used for batched beam search algorithms. Similar to score_hypothesis method. + + Args: + hypothesis: List of Hypotheses. Refer to rnnt_utils.Hypothesis. + cache: Dict which contains a cache to avoid duplicate computations. + batch_states: List of torch.Tensor which represent the states of the RNN for this batch. + Each state is of shape [L, B, H] + + Returns: + Returns a tuple (b_y, b_states, lm_tokens) such that: + b_y is a torch.Tensor of shape [B, 1, H] representing the scores of the last tokens in the Hypotheses. + b_state is a list of list of RNN states, each of shape [L, B, H]. + Represented as B x List[states]. + lm_token is a list of the final integer tokens of the hypotheses in the batch. + """ + final_batch = len(hypotheses) + + if final_batch == 0: + raise ValueError("No hypotheses was provided for the batch!") + + _p = next(self.parameters()) + device = _p.device + dtype = _p.dtype + + tokens = [] + process = [] + done = [None for _ in range(final_batch)] + + # For each hypothesis, cache the last token of the sequence and the current states + for i, hyp in enumerate(hypotheses): + sequence = tuple(hyp.y_sequence) + + if sequence in cache: + done[i] = cache[sequence] + else: + tokens.append(hyp.y_sequence[-1]) + process.append((sequence, hyp.dec_state)) + + if process: + batch = len(process) + + # convert list of tokens to torch.Tensor, then reshape. + tokens = torch.tensor(tokens, device=device, dtype=torch.long).view(batch, -1) + dec_states = self.initialize_state(tokens.to(dtype=dtype)) # [L, B, H] + dec_states = self.batch_initialize_states(dec_states, [d_state for seq, d_state in process]) + + y, dec_states = self.predict( + tokens, state=dec_states, add_sos=False, batch_size=batch + ) # [B, 1, H], List([L, 1, H]) + + # Update done states and cache shared by entire batch. + j = 0 + for i in range(final_batch): + if done[i] is None: + # Select sample's state from the batch state list + new_state = self.batch_select_state(dec_states, j) + + # Cache [1, H] scores of the current y_j, and its corresponding state + done[i] = (y[j], new_state) + cache[process[j][0]] = (y[j], new_state) + + j += 1 + + # Set the incoming batch states with the new states obtained from `done`. + batch_states = self.batch_initialize_states(batch_states, [d_state for y_j, d_state in done]) + + # Create batch of all output scores + # List[1, 1, H] -> [B, 1, H] + batch_y = torch.stack([y_j for y_j, d_state in done]) + + # Extract the last tokens from all hypotheses and convert to a tensor + lm_tokens = torch.tensor([h.y_sequence[-1] for h in hypotheses], device=device, dtype=torch.long).view( + final_batch + ) + + return batch_y, batch_states, lm_tokens + + def batch_initialize_states(self, batch_states: List[torch.Tensor], decoder_states: List[List[torch.Tensor]]): + """ + Create batch of decoder states. + + Args: + batch_states (list): batch of decoder states + ([L x (B, H)], [L x (B, H)]) + + decoder_states (list of list): list of decoder states + [B x ([L x (1, H)], [L x (1, H)])] + + Returns: + batch_states (tuple): batch of decoder states + ([L x (B, H)], [L x (B, H)]) + """ + # LSTM has 2 states + for layer in range(self.pred_rnn_layers): + for state_id in range(len(batch_states)): + batch_states[state_id][layer] = torch.stack([s[state_id][layer] for s in decoder_states]) + + return batch_states + + def batch_select_state(self, batch_states: List[torch.Tensor], idx: int) -> List[List[torch.Tensor]]: + """Get decoder state from batch of states, for given id. + + Args: + batch_states (list): batch of decoder states + ([L x (B, H)], [L x (B, H)]) + + idx (int): index to extract state from batch of states + + Returns: + (tuple): decoder states for given id + ([L x (1, H)], [L x (1, H)]) + """ + state_list = [] + for state_id in range(len(batch_states)): + states = [batch_states[state_id][layer][idx] for layer in range(self.pred_rnn_layers)] + state_list.append(states) + + return state_list + + +def random_replace(inputs: torch.Tensor, rep_prob, rep_id): + mask = torch.bernoulli(torch.full(inputs.size(), rep_prob, device=inputs.device)).type(inputs.dtype) + return mask * rep_id + (1 - mask) * inputs + + +def identity(x): + return x diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/modules/wav2vec_encoder.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/modules/wav2vec_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..6e33e0f722680950a65d04a5b4089916c6b059f5 --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/modules/wav2vec_encoder.py @@ -0,0 +1,395 @@ +import torch +from torch import nn + +from nemo.collections.asr.modules.wav2vec_modules import GumbelVectorQuantizer, compute_mask_indices +from nemo.collections.asr.parts.wav2vec import ConvFeatureEncoder, Wav2VecTransformerEncoder, GradMultiply, \ + TransformerEncoder + + +def buffered_arange(max): + if not hasattr(buffered_arange, "buf"): + buffered_arange.buf = torch.LongTensor() + if max > buffered_arange.buf.numel(): + buffered_arange.buf.resize_(max) + torch.arange(max, out=buffered_arange.buf) + return buffered_arange.buf[:max] + + +class Wav2VecEncoderModel(nn.Module): + def __init__(self, cfg): + super().__init__() + + feature_enc_layers = cfg.conv_feature_encoder.conv_feature_layers + self.embed = feature_enc_layers[-1][0] # Select last conv output layer dimension + + self.feature_extractor = ConvFeatureEncoder( + conv_layers=feature_enc_layers, + mode=cfg.conv_feature_encoder.extractor_mode, + conv_bias=cfg.conv_feature_encoder.conv_bias, + ) + + encoder_embed_dim = cfg.transformer_encoder.encoder.embedding_dim + self.post_extract_proj = ( + nn.Linear(self.embed, encoder_embed_dim) + if self.embed != encoder_embed_dim and not cfg.quantizer.quantize_input + else None + ) + assert not cfg.quantizer.quantize_input # finetune expect this + + self.mask_cfg = cfg.masking + + self.dropout_input = nn.Dropout(cfg.dropout_input) + self.dropout_features = nn.Dropout(cfg.dropout_features) + + self.feature_grad_mult = cfg.feature_grad_mult + + self.quantizer = None + self.input_quantizer = None + + self.n_negatives = cfg.n_negatives + self.cross_sample_negatives = cfg.cross_sample_negatives + self.codebook_negatives = cfg.codebook_negatives + self.negatives_from_everywhere = cfg.negatives_from_everywhere + + final_dim = cfg.final_dim if cfg.final_dim > 0 else encoder_embed_dim + self.final_dim = final_dim + self.quantize_targets = cfg.quantizer.quantize_targets + if self.quantize_targets: + assert cfg.quantizer.targets_bottleneck_dim is None + vq_dim = cfg.quantizer.latent_dim if cfg.quantizer.latent_dim > 0 else final_dim + self.quantizer = GumbelVectorQuantizer( + dim=self.embed, + num_vars=cfg.quantizer.latent_vars, + temp=cfg.quantizer.latent_temp, + groups=cfg.quantizer.latent_groups, + combine_groups=False, + vq_dim=vq_dim, + time_first=True, + ) + self.project_q = nn.Linear(vq_dim, final_dim) + else: + assert cfg.ctc_loss.prob_ppl_weight == 0 + targets_bottleneck_dim = cfg.quantizer.targets_bottleneck_dim + if targets_bottleneck_dim is None: + self.project_q = nn.Linear(self.embed, final_dim) + else: + act_fn_dic = {'relu': nn.ReLU, 'gelu': nn.GELU} + targets_proj_act_fn = cfg.quantizer.targets_bottleneck_act_fn + targets_proj_layers = ( + [nn.Linear(self.embed, targets_bottleneck_dim)] + + ([] if targets_proj_act_fn is None else [act_fn_dic[targets_proj_act_fn]]) + + [nn.Linear(targets_bottleneck_dim, final_dim)] + + ) + self.project_q = torch.nn.Sequential(*targets_proj_layers) + + if cfg.quantizer.quantize_input: + if cfg.quantizer.same_quantizer and self.quantizer is not None: + vq_dim = final_dim + self.input_quantizer = self.quantizer + else: + vq_dim = cfg.quantizer.latent_dim if cfg.quantizer.latent_dim > 0 else encoder_embed_dim + self.input_quantizer = GumbelVectorQuantizer( + dim=self.embed, + num_vars=cfg.quantizer.latent_vars, + temp=cfg.quantizer.latent_temp, + groups=cfg.quantizer.latent_groups, + combine_groups=False, + vq_dim=vq_dim, + time_first=True, + ) + self.project_inp = nn.Linear(vq_dim, encoder_embed_dim) + + self.mask_emb = nn.Parameter(torch.FloatTensor(encoder_embed_dim).uniform_()) + + if cfg.transformer_encoder.use_pytorch_transformer: + self.encoder = Wav2VecTransformerEncoder(cfg.transformer_encoder) + else: + self.encoder = TransformerEncoder(cfg.transformer_encoder) + self.layer_norm = nn.LayerNorm(self.embed) + + self.target_glu = None + if cfg.target_glu: + self.target_glu = nn.Sequential(nn.Linear(final_dim, final_dim * 2), nn.GLU()) + + self.final_proj = nn.Linear(encoder_embed_dim, final_dim) + + def forward(self, source, source_len, *, mask=True, features_only=False) -> tuple: + prob_ppl_loss, cur_temp = None, None + + if self.feature_grad_mult > 0: + features = self.feature_extractor(source) + if self.feature_grad_mult != 1.0: + features = GradMultiply.apply(features, self.feature_grad_mult) + else: + with torch.no_grad(): + features = self.feature_extractor(source) + feature_lens = self.feature_extractor.get_subsampled_lens(source_len) + padding_mask = self._create_padding_mask(feature_lens) + assert feature_lens.max() == features.shape[2] == padding_mask.shape[1] + + features = features.transpose(1, 2) + + features_penalty = None if features_only else features[~padding_mask].float().pow(2).mean() # L2 Norm on features + + features = self.layer_norm(features) + unmasked_features = None if features_only else features.clone() + + if self.post_extract_proj is not None: + features = self.post_extract_proj(features) + + features = self.dropout_input(features) + if not features_only: + unmasked_features = self.dropout_features(unmasked_features) + + assert self.input_quantizer is None + # if self.input_quantizer: + # features, prob_ppl_loss, cur_codebook_temp = self.input_quantizer(features) + # features = self.project_inp(features) + if mask: + logits, mask_indices, mask_num = self.apply_mask(features, padding_mask) + if features_only: + targets = None + elif mask_indices is not None: + targets = unmasked_features[mask_indices] + if self.mask_cfg.mask_shrink_to_batch_min: + targets = targets.view( + unmasked_features.size(0), -1, unmasked_features.size(-1) + ) + else: + # fake batch dim 1 + targets = targets.view( + 1, -1, unmasked_features.size(-1) + ) + assert targets.shape[1] == sum(mask_num) + else: + targets = unmasked_features + else: + logits = features + targets = None if features_only else unmasked_features + mask_indices = None + mask_num = None + + logits = self.encoder(logits, padding_mask=padding_mask) + + if features_only: + return logits, feature_lens + + if self.quantize_targets: + targets, prob_ppl_loss, cur_temp, prob_ppl = self.quantizer(targets) + targets = self.project_q(targets) + + if self.negatives_from_everywhere: + assert self.mask_cfg.mask_shrink_to_batch_min + neg_cands, *_ = self.quantizer(unmasked_features) + sampled_negatives, _ = self.sample_negatives(neg_cands, targets.size(1)) + sampled_negatives = self.project_q(sampled_negatives) + else: + if self.mask_cfg.mask_shrink_to_batch_min: + sampled_negatives, _ = self.sample_negatives(targets, targets.size(1)) + else: + sampled_negatives, _ = self.sample_negatives_flat(targets, mask_num) + + if self.codebook_negatives > 0: + assert self.mask_cfg.mask_shrink_to_batch_min + cb_negs = self.quantizer.sample_from_codebook( + targets.size(0) * targets.size(1), self.codebook_negatives + ) + cb_negs = cb_negs.view( + self.codebook_negatives, targets.size(0), targets.size(1), -1 + ) # order doesnt matter + cb_negs = self.project_q(cb_negs) + sampled_negatives = torch.cat([sampled_negatives, cb_negs], dim=0) + else: + targets = self.project_q(targets) + prob_ppl = None + + if self.negatives_from_everywhere: + assert self.mask_cfg.mask_shrink_to_batch_min + sampled_negatives, _ = self.sample_negatives(unmasked_features, targets.size(1)) + sampled_negatives = self.project_q(sampled_negatives) + else: + if self.mask_cfg.mask_shrink_to_batch_min: + sampled_negatives, _ = self.sample_negatives(targets, targets.size(1)) + else: + sampled_negatives, _ = self.sample_negatives_flat(targets, mask_num) + + mask_logits = logits[mask_indices] + if self.mask_cfg.mask_shrink_to_batch_min: + mask_logits = mask_logits.view(logits.size(0), -1, logits.size(-1)) + else: + # fake batch dim to 1 + mask_logits = mask_logits.view(1, -1, logits.size(-1)) + + if self.target_glu: + targets = self.target_glu(targets) + sampled_negatives = self.target_glu(sampled_negatives) + + mask_logits = self.final_proj(mask_logits) + + return mask_logits, targets, sampled_negatives, padding_mask, features_penalty, prob_ppl_loss, cur_temp, prob_ppl + + def extract_features(self, source, audio_lengths, mask=False): + padding_mask = self._create_padding_mask(audio_lengths) + return self(source=source, padding_mask=padding_mask, mask=mask, features_only=True) + + def remove_pretraining_modules(self): + self.quantizer = None + self.project_q = None + self.target_glu = None + self.final_proj = None + self.dropout_features = None + self.input_quantizer = None + self.project_q = None + self.project_inp = None + self.target_glu = None + + def _update_quantizer_temp(self): + if self.quantizer: + self.quantizer.set_num_updates(self.trainer.global_step) + if self.input_quantizer: + self.input_quantizer.set_num_updates(self.trainer.global_step) + + def apply_mask(self, x, padding_mask): + B, T, C = x.shape + if self.mask_cfg.mask_prob > 0: + mask_indices, mask_num = compute_mask_indices( + (B, T), + padding_mask, + self.mask_cfg.mask_prob, + self.mask_cfg.mask_length, + self.mask_cfg.mask_type, + self.mask_cfg.mask_other, + min_masks=2, + no_overlap=self.mask_cfg.no_mask_overlap, + min_space=self.mask_cfg.mask_min_space, + shrink_to_batch_min=self.mask_cfg.mask_shrink_to_batch_min, + ) + mask_indices = torch.from_numpy(mask_indices).to(x.device) + mask_emb = self.mask_emb.type_as(x) + x[mask_indices] = mask_emb + else: + mask_indices = None + + if self.mask_cfg.mask_channel_prob > 0: + # assert self.mask_cfg.mask_shrink_to_batch_min + mask_channel_indices, _ = compute_mask_indices( + (B, C), + None, + self.mask_cfg.mask_channel_prob, + self.mask_cfg.mask_channel_length, + self.mask_cfg.mask_channel_type, + self.mask_cfg.mask_channel_other, + no_overlap=self.mask_cfg.no_mask_channel_overlap, + min_space=self.mask_cfg.mask_channel_min_space, + shrink_to_batch_min=self.mask_cfg.mask_channel_shrink_to_batch_min, + ) + mask_channel_indices = torch.from_numpy(mask_channel_indices).to(x.device).unsqueeze(1).expand(-1, T, -1) + x[mask_channel_indices] = 0 + + assert len(mask_num) == B + return x, mask_indices, mask_num + + def sample_negatives(self, y, num): + + if self.n_negatives == 0 and self.cross_sample_negatives == 0: + return y.new(0) + + bsz, tsz, fsz = y.shape + y = y.view(-1, fsz) # BTC => (BxT)C + + cross_high = tsz * bsz + high = tsz + with torch.no_grad(): + assert high > 1, f"{bsz, tsz, fsz}" + + if self.n_negatives > 0: + tszs = buffered_arange(num).unsqueeze(-1).expand(-1, self.n_negatives).flatten() + + neg_idxs = torch.randint(low=0, high=high - 1, size=(bsz, self.n_negatives * num)) + neg_idxs[neg_idxs >= tszs] += 1 + + if self.cross_sample_negatives > 0: + tszs = buffered_arange(num).unsqueeze(-1).expand(-1, self.cross_sample_negatives).flatten() + + cross_neg_idxs = torch.randint( + low=0, high=cross_high - 1, size=(bsz, self.cross_sample_negatives * num), + ) + cross_neg_idxs[cross_neg_idxs >= tszs] += 1 + + if self.n_negatives > 0: + for i in range(1, bsz): + neg_idxs[i] += i * high + else: + neg_idxs = cross_neg_idxs + + if self.cross_sample_negatives > 0 and self.n_negatives > 0: + neg_idxs = torch.cat([neg_idxs, cross_neg_idxs], dim=1) + + negs = y[neg_idxs.view(-1)] + negs = negs.view(bsz, num, self.n_negatives + self.cross_sample_negatives, fsz).permute( + 2, 0, 1, 3 + ) # to NxBxTxC + return negs, neg_idxs + + def sample_negatives_flat(self, y, nums): + + if self.n_negatives == 0 and self.cross_sample_negatives == 0: + return y.new(0) + + bsz, tsz, fsz = y.shape + assert bsz == 1 and tsz == sum(nums) # fake batch dim + y = y.view(-1, fsz) # BTC => (BxT)C + + # cross_high = tsz * bsz + + neg_idxs_l = [] + idx_start = 0 + with torch.no_grad(): + for i, num_i in enumerate(nums): + assert num_i > 1, f"{bsz, tsz, fsz}" + + assert self.n_negatives > 0 + tszs_i = buffered_arange(num_i).unsqueeze(-1).expand(-1, self.n_negatives).flatten() + + high_i = num_i + neg_idxs_i = torch.randint(low=0, high=high_i - 1, size=(self.n_negatives * num_i,)) + neg_idxs_i[neg_idxs_i >= tszs_i] += 1 + + neg_idxs_i += idx_start + idx_start += num_i + + neg_idxs_l.append(neg_idxs_i) + + assert self.cross_sample_negatives == 0 + # if self.cross_sample_negatives > 0: + # tszs = buffered_arange(num_i).unsqueeze(-1).expand(-1, self.cross_sample_negatives).flatten() + # + # cross_neg_idxs = torch.randint( + # low=0, high=cross_high - 1, size=(self.cross_sample_negatives * num_i), + # ) + # cross_neg_idxs[cross_neg_idxs >= tszs] += 1 + + # if self.n_negatives <= 0: + # neg_idxs = cross_neg_idxs + + # if self.cross_sample_negatives > 0 and self.n_negatives > 0: + # neg_idxs = torch.cat([neg_idxs, cross_neg_idxs], dim=1) + + neg_idxs = torch.cat(neg_idxs_l) + assert neg_idxs.ndim == 1 + + negs = y[neg_idxs] + negs = negs.view(bsz, sum(nums), self.n_negatives + self.cross_sample_negatives, fsz).permute( + 2, 0, 1, 3 + ) # to NxBxTxC + return negs, neg_idxs + + def _create_padding_mask(self, audio_lengths): + # Broadcast to vectorize creating the padding mask + max_len = max(audio_lengths) + padding_mask = torch.arange(max_len, device=audio_lengths.device) + padding_mask = padding_mask.expand(len(audio_lengths), max_len) < audio_lengths.unsqueeze(1) + # Negate to false where no padding + padding_mask = ~padding_mask + return padding_mask diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/modules/wav2vec_modules.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/modules/wav2vec_modules.py new file mode 100644 index 0000000000000000000000000000000000000000..5c4c2c516b899b9c313504dd77fe645c1dfe97fa --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/modules/wav2vec_modules.py @@ -0,0 +1,328 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional, Tuple + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn + +from nemo.collections.asr.models.wav2vec.wav2vec_config import Wav2VecMaskType +from nemo.core import NeuralModule +from nemo.core.neural_types import EncodedRepresentation, LossType, NeuralType + + +class GumbelVectorQuantizer(NeuralModule): + def __init__( + self, + dim, + num_vars, + temp, + groups, + combine_groups, + vq_dim, + time_first, + activation=nn.GELU(), + weight_proj_depth=1, + weight_proj_factor=1, + ): + """Vector quantization using gumbel softmax + + Args: + dim: input dimension (channels) + num_vars: number of quantized vectors per group + temp: temperature for training. this should be a tuple of 3 elements: (start, stop, decay factor) + groups: number of groups for vector quantization + combine_groups: whether to use the vectors for all groups + vq_dim: dimensionality of the resulting quantized vector + time_first: if true, expect input in BxTxC format, otherwise in BxCxT + activation: what activation to use (should be a module). this is only used if weight_proj_depth is > 1 + weight_proj_depth: number of layers (with activation in between) to project input before computing logits + weight_proj_factor: this is used only if weight_proj_depth is > 1. scales the inner dimensionality of + projections by this factor + """ + super().__init__() + + self.groups = groups + self.combine_groups = combine_groups + self.input_dim = dim + self.num_vars = num_vars + self.time_first = time_first + + assert vq_dim % groups == 0, f"dim {vq_dim} must be divisible by groups {groups} for concatenation" + + var_dim = vq_dim // groups + num_groups = groups if not combine_groups else 1 + + self.vars = nn.Parameter(torch.FloatTensor(1, num_groups * num_vars, var_dim)) + nn.init.uniform_(self.vars) + + if weight_proj_depth > 1: + + def block(input_dim, output_dim): + return nn.Sequential(nn.Linear(input_dim, output_dim), activation) + + inner_dim = self.input_dim * weight_proj_factor + self.weight_proj = nn.Sequential( + *[block(self.input_dim if i == 0 else inner_dim, inner_dim) for i in range(weight_proj_depth - 1)], + nn.Linear(inner_dim, groups * num_vars), + ) + else: + self.weight_proj = nn.Linear(self.input_dim, groups * num_vars) + nn.init.normal_(self.weight_proj.weight, mean=0, std=1) + nn.init.zeros_(self.weight_proj.bias) + + assert len(temp) == 3, "Quantize temperature should be a tuple of 3 elements: (start, stop, decay factor)" + + self.max_temp, self.min_temp, self.temp_decay = temp + self.curr_temp = self.max_temp + self.codebook_indices = None + + def set_num_updates(self, num_updates): + self.curr_temp = max(self.max_temp * self.temp_decay ** num_updates, self.min_temp) + + def get_codebook_indices(self): + if self.codebook_indices is None: + from itertools import product + + p = [range(self.num_vars)] * self.groups + inds = list(product(*p)) + self.codebook_indices = torch.tensor(inds, dtype=torch.long, device=self.vars.device).flatten() + + if not self.combine_groups: + self.codebook_indices = self.codebook_indices.view(self.num_vars ** self.groups, -1) + for b in range(1, self.groups): + self.codebook_indices[:, b] += self.num_vars * b + self.codebook_indices = self.codebook_indices.flatten() + return self.codebook_indices + + def sample_from_codebook(self, b, n): + indices = self.get_codebook_indices() + indices = indices.view(-1, self.groups) + cb_size = indices.size(0) + assert n < cb_size, f"sample size {n} is greater than size of codebook {cb_size}" + sample_idx = torch.randint(low=0, high=cb_size, size=(b * n,)) + indices = indices[sample_idx] + + z = self.vars.squeeze(0).index_select(0, indices.flatten()).view(b, n, -1) + return z + + @property + def input_types(self): + """Returns definitions of module input ports. + """ + if self.time_first: + return {"x": NeuralType(('B', 'T', 'D'), EncodedRepresentation())} + return {"x": NeuralType(('B', 'D', 'T'), EncodedRepresentation())} + + @property + def output_types(self): + """Returns definitions of module output ports. + """ + if self.time_first: + return { + "x": NeuralType(('B', 'T', 'D'), EncodedRepresentation()), + "quantize_prob_ppl": NeuralType(elements_type=LossType()), + } + return { + "x": NeuralType(('B', 'D', 'T'), EncodedRepresentation()), + "quantize_prob_ppl": NeuralType(elements_type=LossType()), + } + + def forward(self, x, quant_only=False): + + if not self.time_first: + x = x.transpose(1, 2) + + bsz, tsz, fsz = x.shape + x = x.reshape(-1, fsz) + x = self.weight_proj(x) + x = x.view(bsz * tsz * self.groups, -1) + + if self.training: + hard_x = None + else: + _, k = x.max(-1) + hard_x = x.new_zeros(*x.shape).scatter_(-1, k.view(-1, 1), 1.0).view(bsz * tsz, self.groups, -1) + + # Calculate quantize prob perplexity + num_vars = self.num_vars * self.groups + avg_probs = torch.softmax(x.view(bsz * tsz, self.groups, -1).float(), dim=-1).mean(dim=0) + prob_ppl = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-7), dim=-1)).sum() + prob_ppl_loss = (num_vars - prob_ppl) / num_vars + + if self.training: + x = F.gumbel_softmax(x.float(), tau=self.curr_temp, hard=True).type_as(x) + else: + x = hard_x + + x = x.view(bsz * tsz, -1) + + vars = self.vars + if self.combine_groups: + vars = vars.repeat(1, self.groups, 1) + + if quant_only: + quant_ids = ( + x.view(bsz * tsz * self.groups, -1) + .argmax(dim=-1) + .view(bsz, tsz, self.groups) + .detach() + ) + + if self.groups == 1: + x = torch.matmul(x, vars.squeeze(0)) + else: + x = x.unsqueeze(-1) * vars + x = x.view(bsz * tsz, self.groups, self.num_vars, -1) + x = x.sum(-2) + x = x.view(bsz, tsz, -1) + + cur_codebook_temp = self.curr_temp + + if not self.time_first: + x = x.transpose(1, 2) # BTC -> BCT + + if quant_only: + return x, quant_ids + + return x, prob_ppl_loss, cur_codebook_temp, prob_ppl + + +def compute_mask_indices( + shape: Tuple[int, int], + padding_mask: Optional[torch.Tensor], + mask_prob: float, + mask_length: int, + mask_type: Wav2VecMaskType = Wav2VecMaskType.static, + mask_other: float = 0.0, + min_masks: int = 0, + no_overlap: bool = False, + min_space: int = 0, + shrink_to_batch_min: bool = True, + mask_positions = None +): + """ + Computes random mask spans for a given shape + Args: + shape: the the shape for which to compute masks. + should be of size 2 where first element is batch size and 2nd is timesteps + padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements + mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by + number of timesteps divided by length of mask span to mask approximately this percentage of all elements. + however due to overlaps, the actual number will be smaller (unless no_overlap is True) + mask_type: how to compute mask lengths + static = fixed size + uniform = sample from uniform distribution [mask_other, mask_length*2] + normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element + poisson = sample from possion distribution with lambda = mask length + min_masks: minimum number of masked spans + no_overlap: if false, will switch to an alternative recursive algorithm that prevents spans from overlapping + min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans + """ + + bsz, all_sz = shape + mask = np.full((bsz, all_sz), False) + + all_num_mask = int( + # add a random number for probabilistic rounding + mask_prob * all_sz / float(mask_length) + + np.random.rand() + ) + + all_num_mask = max(min_masks, all_num_mask) + + mask_idcs = [] + for i in range(bsz): + if padding_mask is not None: + sz = all_sz - padding_mask[i].long().sum().item() + num_mask = int( + # add a random number for probabilistic rounding + mask_prob * sz / float(mask_length) + + np.random.rand() + ) + num_mask = max(min_masks, num_mask) + else: + sz = all_sz + num_mask = all_num_mask + + if mask_type.value is Wav2VecMaskType.static.value: + lengths = np.full(num_mask, mask_length) + elif mask_type.value is Wav2VecMaskType.uniform: + lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask) + elif mask_type.value is Wav2VecMaskType.normal.value: + lengths = np.random.normal(mask_length, mask_other, size=num_mask) + lengths = [max(1, int(round(x))) for x in lengths] + elif mask_type.value is Wav2VecMaskType.poisson.value: + lengths = np.random.poisson(mask_length, size=num_mask) + lengths = [int(round(x)) for x in lengths] + else: + raise Exception("unknown mask selection " + str(mask_type)) + + if sum(lengths) == 0: + lengths[0] = min(mask_length, sz - 1) + + if no_overlap: + mask_idc = [] + + def arrange(s, e, length, keep_length): + span_start = np.random.randint(s, e - length) + mask_idc.extend(span_start + i for i in range(length)) + + new_parts = [] + if span_start - s - min_space >= keep_length: + new_parts.append((s, span_start - min_space + 1)) + if e - span_start - keep_length - min_space > keep_length: + new_parts.append((span_start + length + min_space, e)) + return new_parts + + parts = [(0, sz)] + min_length = min(lengths) + for length in sorted(lengths, reverse=True): + lens = np.fromiter((e - s if e - s >= length + min_space else 0 for s, e in parts), np.int) + l_sum = np.sum(lens) + if l_sum == 0: + break + probs = lens / np.sum(lens) + c = np.random.choice(len(parts), p=probs) + s, e = parts.pop(c) + parts.extend(arrange(s, e, length, min_length)) + mask_idc = np.asarray(mask_idc) + else: + min_len = min(lengths) + if sz - min_len <= num_mask: + min_len = sz - num_mask - 1 + + mask_idc = np.random.choice(sz - min_len, num_mask, replace=False) + + mask_idc = np.asarray([mask_idc[j] + offset for j in range(len(mask_idc)) for offset in range(lengths[j])]) + + mask_idcs.append(np.unique(mask_idc[mask_idc < sz])) + + mask_num = [len(m) for m in mask_idcs] + min_len = min(mask_num) + for i, mask_idc in enumerate(mask_idcs): + if shrink_to_batch_min and len(mask_idc) > min_len: + mask_idc = np.random.choice(mask_idc, min_len, replace=False) + mask[i, mask_idc] = True + if mask_positions is not None: + mask_positions.append(mask_idc) + + return mask, mask_num diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/modules/yin.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/modules/yin.py new file mode 100644 index 0000000000000000000000000000000000000000..acd56ff8ab299ffa2c4cf04d46c18e64b0eeca3d --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/modules/yin.py @@ -0,0 +1,125 @@ +""" Yin Pitch Estimation Algorithm + +This is a fully-vectorized implementation of the Yin pitch estimation +algorithm for PyTorch. It is based on the excellent NumPy baseline by +Patrice Guyot (github.com/patriceguyot), with changes for batching and +vectorization of the iterative search. + +References: + https://asa.scitation.org/doi/10.1121/1.1458024 + https://github.com/patriceguyot/Yin + +License: + MIT License + Copyright Ā© 2022 Brent M. Spell + +""" + +import typing as T + +import numpy as np +import torch + + +def estimate( + signal: T.Union[T.List, np.ndarray, torch.Tensor], + sample_rate: float, + pitch_min: float = 20, + pitch_max: float = 20000, + frame_stride: float = 0.01, + threshold: float = 0.1, +) -> torch.Tensor: + """estimate the pitch (fundamental frequency) of a signal + + This function attempts to determine the pitch of a signal via the + Yin algorithm. Accuracy can be improved by sampling the signal at a + higher rate, especially for higher-frequency pitches, and by narrowing + the values of pitch_min and pitch_max. For example, good values for + speech signals are pitch_min=60 and pitch_max=500. frame_stride can also + be tuned to the expected minimum rate of pitch change of the signal: + 10ms is commonly used for speech. + + The speed and memory usage of the algorithm are also determined by the + pitch_min parameter, which is used to window the audio signal into + 2*sample_rate/pitch_min sliding windows. A higher pitch_min corresponds to + less memory usage and faster running time. + + Args: + signal: the signal vector (1D) or [batch, time] tensor to analyze + sample_rate: sample rate, in Hz, of the signal + pitch_min: expected lower bound of the pitch + pitch_max: expected upper bound of the pitch + frame_stride: overlapping window stride, in seconds, which determines + the number of pitch values returned + threshold: harmonic threshold value (see paper) + + Returns: + pitch: PyTorch tensor of pitch estimations, one for each frame of + the windowed signal, an entry of 0 corresponds to a non-periodic + frame, where no periodic signal was detected + + """ + + signal = torch.as_tensor(signal) + + # convert frequencies to samples, ensure windows can fit 2 whole periods + tau_min = int(sample_rate / pitch_max) + tau_max = int(sample_rate / pitch_min) + frame_length = 2 * tau_max + frame_stride = int(frame_stride * sample_rate) + + # compute the fundamental periods + frames = _frame(signal, frame_length, frame_stride) + cmdf = _diff(frames, tau_max)[..., tau_min:] + tau = _search(cmdf, tau_max, threshold) + + # convert the periods to frequencies (if periodic) and output + return torch.where( + tau > 0, + sample_rate / (tau + tau_min + 1).type(signal.dtype), + torch.tensor(0, device=tau.device).type(signal.dtype), + ) + + +def _frame(signal: torch.Tensor, frame_length: int, frame_stride: int) -> torch.Tensor: + # window the signal into overlapping frames, padding to at least 1 frame + if signal.shape[-1] < frame_length: + signal = torch.nn.functional.pad(signal, [0, frame_length - signal.shape[-1]]) + return signal.unfold(dimension=-1, size=frame_length, step=frame_stride) + + +def _diff(frames: torch.Tensor, tau_max: int) -> torch.Tensor: + # compute the frame-wise autocorrelation using the FFT + fft_size = 2 ** (-int(-np.log(frames.shape[-1]) // np.log(2)) + 1) + fft = torch.fft.rfft(frames, fft_size, dim=-1) + corr = torch.fft.irfft(fft * fft.conj())[..., :tau_max] + + # difference function (equation 6) + sqrcs = torch.nn.functional.pad((frames * frames).cumsum(-1), [1, 0]) + corr_0 = sqrcs[..., -1:] + corr_tau = sqrcs.flip(-1)[..., :tau_max] - sqrcs[..., :tau_max] + diff = corr_0 + corr_tau - 2 * corr + + # cumulative mean normalized difference function (equation 8) + return ( + diff[..., 1:] + * torch.arange(1, diff.shape[-1], device=diff.device) + / torch.maximum( + diff[..., 1:].cumsum(-1), + torch.tensor(1e-5, device=diff.device), + ) + ) + + +def _search(cmdf: torch.Tensor, tau_max: int, threshold: float) -> torch.Tensor: + # mask all periods after the first cmdf below the threshold + # if none are below threshold (argmax=0), this is a non-periodic frame + first_below = (cmdf < threshold).int().argmax(-1, keepdim=True) + first_below = torch.where(first_below > 0, first_below, tau_max) + beyond_threshold = torch.arange(cmdf.shape[-1], device=cmdf.device) >= first_below + + # mask all periods with upward sloping cmdf to find the local minimum + increasing_slope = torch.nn.functional.pad(cmdf.diff() >= 0.0, [0, 1], value=1) + + # find the first period satisfying both constraints + return (beyond_threshold & increasing_slope).int().argmax(-1) diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/__init__.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9e3250071955216f6abc505e6181fb59931baa8d --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/activations.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/activations.py new file mode 100644 index 0000000000000000000000000000000000000000..627eef2957174a2412c3d1a43bbfeb0d45f00503 --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/activations.py @@ -0,0 +1,27 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn + +__all__ = ['Swish'] + + +class Swish(nn.Module): + """ + Swish activation function introduced in 'https://arxiv.org/abs/1710.05941' + """ + + def forward(self, x): + return x * torch.sigmoid(x) diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/cleaners.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/cleaners.py new file mode 100644 index 0000000000000000000000000000000000000000..9af9a41e0ef09d5f7521f2e495745279b96e3b50 --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/cleaners.py @@ -0,0 +1,204 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re + +import inflect +from unidecode import unidecode + +from nemo.utils import logging + +NUM_CHECK = re.compile(r'([$]?)(^|\s)(\S*[0-9]\S*)(?=(\s|$)((\S*)(\s|$))?)') + +TIME_CHECK = re.compile(r'([0-9]{1,2}):([0-9]{2})(am|pm)?') +CURRENCY_CHECK = re.compile(r'\$') +ORD_CHECK = re.compile(r'([0-9]+)(st|nd|rd|th)') +THREE_CHECK = re.compile(r'([0-9]{3})([.,][0-9]{1,2})?([!.?])?$') +DECIMAL_CHECK = re.compile(r'([.,][0-9]{1,2})$') + +ABBREVIATIONS_COMMON = [ + (re.compile('\\b%s\\.' % x[0]), x[1]) + for x in [ + ("ms", "miss"), + ("mrs", "misess"), + ("mr", "mister"), + ("messrs", "messeurs"), + ("dr", "doctor"), + ("drs", "doctors"), + ("st", "saint"), + ("co", "company"), + ("jr", "junior"), + ("sr", "senior"), + ("rev", "reverend"), + ("hon", "honorable"), + ("sgt", "sergeant"), + ("capt", "captain"), + ("maj", "major"), + ("col", "colonel"), + ("lt", "lieutenant"), + ("gen", "general"), + ("prof", "professor"), + ("lb", "pounds"), + ("rep", "representative"), + ("st", "street"), + ("ave", "avenue"), + ("etc", "et cetera"), + ("jan", "january"), + ("feb", "february"), + ("mar", "march"), + ("apr", "april"), + ("jun", "june"), + ("jul", "july"), + ("aug", "august"), + ("sep", "september"), + ("oct", "october"), + ("nov", "november"), + ("dec", "december"), + ] +] + +ABBREVIATIONS_EXPANDED = [ + (re.compile('\\b%s\\.' % x[0]), x[1]) + for x in [ + ("ltd", "limited"), + ("fig", "figure"), + ("figs", "figures"), + ("gent", "gentlemen"), + ("ft", "fort"), + ("esq", "esquire"), + ("prep", "preperation"), + ("bros", "brothers"), + ("ind", "independent"), + ("mme", "madame"), + ("pro", "professional"), + ("vs", "versus"), + ("inc", "include"), + ] +] + +inflect = inflect.engine() + + +def clean_text(string, table, punctuation_to_replace): + warn_common_chars(string) + string = unidecode(string) + string = string.lower() + string = re.sub(r'\s+', " ", string) + string = clean_numbers(string) + string = clean_abbreviations(string) + string = clean_punctuations(string, table, punctuation_to_replace) + string = re.sub(r'\s+', " ", string).strip() + return string + + +def warn_common_chars(string): + if re.search(r'[£€]', string): + logging.warning("Your transcript contains one of 'Ā£' or '€' which we do not currently handle") + + +def clean_numbers(string): + cleaner = NumberCleaner() + string = NUM_CHECK.sub(cleaner.clean, string) + return string + + +def clean_abbreviations(string, expanded=False): + for regex, replacement in ABBREVIATIONS_COMMON: + string = re.sub(regex, replacement, string) + if expanded: + for regex, replacement in ABBREVIATIONS_EXPANDED: + string = re.sub(regex, replacement, string) + return string + + +def clean_punctuations(string, table, punctuation_to_replace): + for punc, replacement in punctuation_to_replace.items(): + string = re.sub('\\{}'.format(punc), " {} ".format(replacement), string) + string = string.translate(table) + return string + + +class NumberCleaner: + def __init__(self): + super().__init__() + self.reset() + + def reset(self): + self.curr_num = [] + self.currency = None + + def format_final_number(self, whole_num, decimal): + if self.currency: + return_string = inflect.number_to_words(whole_num) + return_string += " dollar" if whole_num == 1 else " dollars" + if decimal: + return_string += " and " + inflect.number_to_words(decimal) + return_string += " cent" if whole_num == decimal else " cents" + self.reset() + return return_string + + self.reset() + if decimal: + whole_num += "." + decimal + return inflect.number_to_words(whole_num) + else: + # Check if there are non-numbers + def convert_to_word(match): + return " " + inflect.number_to_words(match.group(0)) + " " + + return re.sub(r'[0-9,]+', convert_to_word, whole_num) + + def clean(self, match): + ws = match.group(2) + number = match.group(3) + _proceeding_symbol = match.group(7) + + time_match = TIME_CHECK.match(number) + if time_match: + string = ws + inflect.number_to_words(time_match.group(1)) + "{}{}" + mins = int(time_match.group(2)) + min_string = "" + if mins != 0: + min_string = " " + inflect.number_to_words(time_match.group(2)) + ampm_string = "" + if time_match.group(3): + ampm_string = " " + time_match.group(3) + return string.format(min_string, ampm_string) + + ord_match = ORD_CHECK.match(number) + if ORD_CHECK.match(number): + return ws + inflect.number_to_words(ord_match.group(0)) + + if self.currency is None: + # Check if it is a currency + self.currency = match.group(1) or CURRENCY_CHECK.match(number) + + # Check to see if next symbol is a number + # If it is a number and it has 3 digits, then it is probably a + # continuation + three_match = THREE_CHECK.match(match.group(6)) + if three_match: + self.curr_num.append(number) + return " " + # Else we can output + else: + # Check for decimals + whole_num = "".join(self.curr_num) + number + decimal = None + decimal_match = DECIMAL_CHECK.search(whole_num) + if decimal_match: + decimal = decimal_match.group(1)[1:] + whole_num = whole_num[: -len(decimal) - 1] + whole_num = re.sub(r'\.', '', whole_num) + return ws + self.format_final_number(whole_num, decimal) diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/collections.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/collections.py new file mode 100644 index 0000000000000000000000000000000000000000..d3ee89e504920c432cd936e694df58ff74e562a9 --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/collections.py @@ -0,0 +1,358 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import collections +import json +import os +from typing import Any, Dict, List, Optional, Union + +import pandas as pd + +from nemo.collections.asr.parts import manifest, parsers +from nemo.utils import logging + + +class _Collection(collections.UserList): + """List of parsed and preprocessed data.""" + + OUTPUT_TYPE = None # Single element output type. + + +class Text(_Collection): + """Simple list of preprocessed text entries, result in list of tokens.""" + + OUTPUT_TYPE = collections.namedtuple('TextEntity', 'tokens') + + def __init__(self, texts: List[str], parser: parsers.CharParser): + """Instantiates text manifest and do the preprocessing step. + + Args: + texts: List of raw texts strings. + parser: Instance of `CharParser` to convert string to tokens. + """ + + data, output_type = [], self.OUTPUT_TYPE + for text in texts: + tokens = parser(text) + + if tokens is None: + logging.warning("Fail to parse '%s' text line.", text) + continue + + data.append(output_type(tokens)) + + super().__init__(data) + + +class FromFileText(Text): + """Another form of texts manifest with reading from file.""" + + def __init__(self, file: str, parser: parsers.CharParser): + """Instantiates text manifest and do the preprocessing step. + + Args: + file: File path to read from. + parser: Instance of `CharParser` to convert string to tokens. + """ + + texts = self.__parse_texts(file) + + super().__init__(texts, parser) + + @staticmethod + def __parse_texts(file: str) -> List[str]: + if not os.path.exists(file): + raise ValueError('Provided texts file does not exists!') + + _, ext = os.path.splitext(file) + if ext == '.csv': + texts = pd.read_csv(file)['transcript'].tolist() + elif ext == '.json': # Not really a correct json. + texts = list(item['text'] for item in manifest.item_iter(file)) + else: + with open(file, 'r') as f: + texts = f.readlines() + + return texts + + +class AudioText(_Collection): + """List of audio-transcript text correspondence with preprocessing.""" + + OUTPUT_TYPE = collections.namedtuple( + typename='AudioTextEntity', field_names='id audio_file duration text_tokens offset text_raw speaker orig_sr aux_tokens', + ) + + def __init__( + self, + ids: List[int], + audio_files: List[str], + durations: List[float], + texts: List[str], + offsets: List[str], + speakers: List[Optional[int]], + orig_sampling_rates: List[Optional[int]], + parser: parsers.CharParser, + min_duration: Optional[float] = None, + max_duration: Optional[float] = None, + max_number: Optional[int] = None, + do_sort_by_duration: bool = False, + index_by_file_id: bool = False, + dup_factor: int = 1, + parse_online: bool = False, + aux_parser = None, + ): + """Instantiates audio-text manifest with filters and preprocessing. + + Args: + ids: List of examples positions. + audio_files: List of audio files. + durations: List of float durations. + texts: List of raw text transcripts. + offsets: List of duration offsets or None. + speakers: List of optional speakers ids. + orig_sampling_rates: List of original sampling rates of audio files. + parser: Instance of `CharParser` to convert string to tokens. + min_duration: Minimum duration to keep entry with (default: None). + max_duration: Maximum duration to keep entry with (default: None). + max_number: Maximum number of samples to collect. + do_sort_by_duration: True if sort samples list by duration. Not compatible with index_by_file_id. + index_by_file_id: If True, saves a mapping from filename base (ID) to index in data. + """ + + self.parse_online = parse_online + self.aux_parser = aux_parser + if self.aux_parser: + assert self.parse_online + if self.parse_online: + print('parse online!') + self.parser = parser + + output_type = self.OUTPUT_TYPE + data, duration_filtered, num_filtered, total_duration = [], 0.0, 0, 0.0 + if index_by_file_id: + self.mapping = {} + + for id_, audio_file, duration, offset, text, speaker, orig_sr in zip( + ids, audio_files, durations, offsets, texts, speakers, orig_sampling_rates + ): + # Duration filters. + if min_duration is not None and duration < min_duration: + duration_filtered += duration + num_filtered += 1 + continue + + if max_duration is not None and duration > max_duration: + duration_filtered += duration + num_filtered += 1 + continue + + if parser is not None: + text_tokens = parser(text) + if text_tokens is None: + duration_filtered += duration + num_filtered += 1 + continue + else: + text_tokens = None + + total_duration += duration + + aux_tokens = None + data.append(output_type(id_, audio_file, duration, text_tokens, offset, text, speaker, orig_sr, aux_tokens)) + if index_by_file_id: + file_id, _ = os.path.splitext(os.path.basename(audio_file)) + self.mapping[file_id] = len(data) - 1 + + # Max number of entities filter. + if len(data) == max_number: + break + + if do_sort_by_duration: + if index_by_file_id: + logging.warning("Tried to sort dataset by duration, but cannot since index_by_file_id is set.") + else: + data.sort(key=lambda entity: entity.duration) + + logging.info("Dataset loaded with %d files totalling %.2f hours", len(data), total_duration / 3600) + logging.info("%d files were filtered totalling %.2f hours", num_filtered, duration_filtered / 3600) + + if dup_factor > 1: + assert not index_by_file_id + data = [d_i for d_i in data for _ in range(dup_factor)] + logging.info("Dataset duplicated %d times", dup_factor) + + super().__init__(data) + + def get_item_online(self, i): + assert self.parse_online + data = super().__getitem__(i) + + if self.aux_parser: + aux_tokens = self.aux_parser(data.text_raw) + data = data._replace(aux_tokens=aux_tokens) + + text_tokens = self.parser(data.text_raw) + return data._replace(text_tokens=text_tokens) + + def __getitem__(self, i): + if self.parse_online: + return self.get_item_online(i) + else: + return super().__getitem__(i) + + +class ASRAudioText(AudioText): + """`AudioText` collector from asr structured json files.""" + + def __init__(self, manifests_files: Union[str, List[str]], *args, **kwargs): + """Parse lists of audio files, durations and transcripts texts. + + Args: + manifests_files: Either single string file or list of such - + manifests to yield items from. + *args: Args to pass to `AudioText` constructor. + **kwargs: Kwargs to pass to `AudioText` constructor. + """ + + ids, audio_files, durations, texts, offsets, speakers, orig_srs = [], [], [], [], [], [], [] + for item in manifest.item_iter(manifests_files): + ids.append(item['id']) + audio_files.append(item['audio_file']) + durations.append(item['duration']) + texts.append(item['text']) + offsets.append(item['offset']) + speakers.append(item['speaker']) + orig_srs.append(item['orig_sr']) + + super().__init__(ids, audio_files, durations, texts, offsets, speakers, orig_srs, *args, **kwargs) + + +class SpeechLabel(_Collection): + """List of audio-label correspondence with preprocessing.""" + + OUTPUT_TYPE = collections.namedtuple(typename='SpeechLabelEntity', field_names='audio_file duration label offset',) + + def __init__( + self, + audio_files: List[str], + durations: List[float], + labels: List[Union[int, str]], + offsets: List[Optional[float]], + min_duration: Optional[float] = None, + max_duration: Optional[float] = None, + max_number: Optional[int] = None, + do_sort_by_duration: bool = False, + ): + """Instantiates audio-label manifest with filters and preprocessing. + + Args: + audio_files: List of audio files. + durations: List of float durations. + labels: List of labels. + offsets: List of offsets or None. + min_duration: Minimum duration to keep entry with (default: None). + max_duration: Maximum duration to keep entry with (default: None). + max_number: Maximum number of samples to collect. + do_sort_by_duration: True if sort samples list by duration. + """ + + output_type = self.OUTPUT_TYPE + data, duration_filtered = [], 0.0 + for audio_file, duration, command, offset in zip(audio_files, durations, labels, offsets): + # Duration filters. + if min_duration is not None and duration < min_duration: + duration_filtered += duration + continue + + if max_duration is not None and duration > max_duration: + duration_filtered += duration + continue + + data.append(output_type(audio_file, duration, command, offset)) + + # Max number of entities filter. + if len(data) == max_number: + break + + if do_sort_by_duration: + data.sort(key=lambda entity: entity.duration) + + logging.info( + "Filtered duration for loading collection is %f.", duration_filtered, + ) + self.uniq_labels = sorted(set(map(lambda x: x.label, data))) + logging.info("# {} files loaded accounting to # {} labels".format(len(data), len(self.uniq_labels))) + + super().__init__(data) + + +class ASRSpeechLabel(SpeechLabel): + """`SpeechLabel` collector from structured json files.""" + + def __init__(self, manifests_files: Union[str, List[str]], *args, **kwargs): + """Parse lists of audio files, durations and transcripts texts. + + Args: + manifests_files: Either single string file or list of such - + manifests to yield items from. + *args: Args to pass to `SpeechLabel` constructor. + **kwargs: Kwargs to pass to `SpeechLabel` constructor. + """ + audio_files, durations, labels, offsets = [], [], [], [] + + for item in manifest.item_iter(manifests_files, parse_func=self.__parse_item): + audio_files.append(item['audio_file']) + durations.append(item['duration']) + labels.append(item['label']) + offsets.append(item['offset']) + + super().__init__(audio_files, durations, labels, offsets, *args, **kwargs) + + def __parse_item(self, line: str, manifest_file: str) -> Dict[str, Any]: + item = json.loads(line) + + # Audio file + if 'audio_filename' in item: + item['audio_file'] = item.pop('audio_filename') + elif 'audio_filepath' in item: + item['audio_file'] = item.pop('audio_filepath') + else: + raise ValueError( + f"Manifest file has invalid json line " f"structure: {line} without proper audio file key." + ) + item['audio_file'] = os.path.expanduser(item['audio_file']) + + # Duration. + if 'duration' not in item: + raise ValueError(f"Manifest file has invalid json line " f"structure: {line} without proper duration key.") + + # Label. + if 'command' in item: + item['label'] = item.pop('command') + elif 'target' in item: + item['label'] = item.pop('target') + elif 'label' in item: + pass + else: + raise ValueError(f"Manifest file has invalid json line " f"structure: {line} without proper label key.") + + item = dict( + audio_file=item['audio_file'], + duration=item['duration'], + label=item['label'], + offset=item.get('offset', None), + ) + + return item diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/compute_wer.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/compute_wer.py new file mode 100644 index 0000000000000000000000000000000000000000..e2a1bac05730ecfd2673575c9d4267997fa53857 --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/compute_wer.py @@ -0,0 +1,63 @@ +import sys +from collections import OrderedDict + +import pandas as pd + +from nemo.collections.asr.metrics.wer_bpe import add_delim_space +from nemo.collections.asr.parts.simple_wer_v2 import SimpleWER + + +def analyze(texts, output_html_path=None, add_zh_space=False): + # def remove_sep(inp_text): + # if SEP in inp_text: + # inp_text = ' '.join(sp for sp in inp_text.split() if sp != SEP) + # return inp_text + wer_obj = SimpleWER( + preprocess_handler=None) + + total_empty_lev = 0 + for fname, true_text, pred_text in texts: + if add_zh_space: + true_text = add_delim_space(true_text) + pred_text = add_delim_space(pred_text) + # if ignore_sep: + # assert SEP not in pred_text + wer_obj.AddHypRef(pred_text, true_text) + + if not pred_text: + print('found empty pred: {}, {}'.format(fname, true_text)) + total_empty_lev += len(true_text.split()) + + str_summary, str_details, _, (wer, total_error, nref) = wer_obj.GetSummaries() + + print('empty token num: {}, empty error ratio: {}'.format(total_empty_lev, total_empty_lev/total_error if total_error > 0 else 0)) + + if output_html_path: + wer_obj.write_html(output_html_path) + return (str_summary, str_details), (wer, total_error, nref) + + +def calc_wer(true_fp, pred_fp): + true_df = pd.read_csv(true_fp, index_col='wav_filename', usecols=['wav_filename', 'transcript'], encoding='utf-8') + + true_dic = true_df['transcript'].to_dict() + + pred_dic = pd.read_csv(pred_fp, index_col='wav_filename', encoding='utf-8', keep_default_na=False)['predicted_transcript'].to_dict(into=OrderedDict) + + assert true_dic.keys() == pred_dic.keys() + + texts = [] + for fname, pred_text in pred_dic.items(): + texts.append([fname, true_dic[fname], pred_text]) + + (str_summary, str_details), (total_wer, total_word_lev, total_word_count) = analyze(texts, + output_html_path=pred_fp + '_diagnosis.html') + print(str_details) + print(str_summary) + return total_wer, (total_word_lev, total_word_count) + + +if __name__ == '__main__': + true_fp, pred_fp = sys.argv[1:] + + calc_wer(true_fp=true_fp, pred_fp=pred_fp) \ No newline at end of file diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/conformer_modules.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/conformer_modules.py new file mode 100644 index 0000000000000000000000000000000000000000..e0079b79c47ea4210eacaf35d10f3254c8d9114d --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/conformer_modules.py @@ -0,0 +1,185 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import torch +from torch import nn as nn +from torch.nn import LayerNorm + +from nemo.collections.asr.parts.activations import Swish +from nemo.collections.asr.parts.multi_head_attention import MultiHeadAttention, RelPositionMultiHeadAttention + +__all__ = ['ConformerConvolution', 'ConformerFeedForward', 'ConformerEncoderBlock'] + + +class ConformerEncoderBlock(torch.nn.Module): + """A single block of the Conformer encoder. + + Args: + d_model (int): input dimension of MultiheadAttentionMechanism and PositionwiseFeedForward + d_ff (int): hidden dimension of PositionwiseFeedForward + n_heads (int): number of heads for multi-head attention + conv_kernel_size (int): kernel size for depthwise convolution in convolution module + dropout (float): dropout probabilities for linear layers + dropout_att (float): dropout probabilities for attention distributions + """ + + def __init__( + self, + d_model, + d_ff, + conv_kernel_size, + self_attention_model, + n_heads, + dropout, + dropout_att, + pos_bias_u, + pos_bias_v, + ): + super(ConformerEncoderBlock, self).__init__() + + self.self_attention_model = self_attention_model + self.n_heads = n_heads + self.fc_factor = 0.5 + + # first feed forward module + self.norm_feed_forward1 = LayerNorm(d_model) + self.feed_forward1 = ConformerFeedForward(d_model=d_model, d_ff=d_ff, dropout=dropout) + + # convolution module + self.norm_conv = LayerNorm(d_model) + self.conv = ConformerConvolution(d_model=d_model, kernel_size=conv_kernel_size) + + # multi-headed self-attention module + self.norm_self_att = LayerNorm(d_model) + if self_attention_model == 'rel_pos': + self.self_attn = RelPositionMultiHeadAttention( + n_head=n_heads, n_feat=d_model, dropout_rate=dropout_att, pos_bias_u=pos_bias_u, pos_bias_v=pos_bias_v + ) + elif self_attention_model == 'abs_pos': + self.self_attn = MultiHeadAttention(n_head=n_heads, n_feat=d_model, dropout_rate=dropout_att) + else: + raise ValueError(f"Not valid self_attention_model: '{self_attention_model}'!") + + # second feed forward module + self.norm_feed_forward2 = LayerNorm(d_model) + self.feed_forward2 = ConformerFeedForward(d_model=d_model, d_ff=d_ff, dropout=dropout) + + self.dropout = nn.Dropout(dropout) + self.norm_out = LayerNorm(d_model) + + def forward(self, x, att_mask=None, pos_emb=None, pad_mask=None): + """ + Args: + x (torch.Tensor): input signals (B, T, d_model) + att_mask (torch.Tensor): attention masks(B, T, T) + pos_emb (torch.Tensor): (L, 1, d_model) + pad_mask (torch.tensor): padding mask + Returns: + x (torch.Tensor): (B, T, d_model) + """ + residual = x + x = self.norm_feed_forward1(x) + x = self.feed_forward1(x) + x = self.fc_factor * self.dropout(x) + residual + + residual = x + x = self.norm_self_att(x) + if self.self_attention_model == 'rel_pos': + x = self.self_attn(query=x, key=x, value=x, mask=att_mask, pos_emb=pos_emb) + elif self.self_attention_model == 'abs_pos': + x = self.self_attn(query=x, key=x, value=x, mask=att_mask) + else: + x = None + x = self.dropout(x) + residual + + residual = x + x = self.norm_conv(x) + x = self.conv(x, pad_mask) + x = self.dropout(x) + residual + + residual = x + x = self.norm_feed_forward2(x) + x = self.feed_forward2(x) + x = self.fc_factor * self.dropout(x) + residual + + x = self.norm_out(x) + return x + + +class ConformerConvolution(nn.Module): + """The convolution module for the Conformer model. + Args: + d_model (int): hidden dimension + kernel_size (int): kernel size for depthwise convolution + """ + + def __init__(self, d_model, kernel_size): + super(ConformerConvolution, self).__init__() + assert (kernel_size - 1) % 2 == 0 + self.d_model = d_model + + self.pointwise_conv1 = nn.Conv1d( + in_channels=d_model, out_channels=d_model * 2, kernel_size=1, stride=1, padding=0, bias=True + ) + self.depthwise_conv = nn.Conv1d( + in_channels=d_model, + out_channels=d_model, + kernel_size=kernel_size, + stride=1, + padding=(kernel_size - 1) // 2, + groups=d_model, + bias=True, + ) + self.batch_norm = nn.BatchNorm1d(d_model) + self.activation = Swish() + self.pointwise_conv2 = nn.Conv1d( + in_channels=d_model, out_channels=d_model, kernel_size=1, stride=1, padding=0, bias=True + ) + + def forward(self, x, pad_mask=None): + x = x.transpose(1, 2) + x = self.pointwise_conv1(x) + + x = nn.functional.glu(x, dim=1) + + if pad_mask is not None: + x.masked_fill_(pad_mask.unsqueeze(1), 0.0) + x = self.depthwise_conv(x) + + x = self.batch_norm(x) + x = self.activation(x) + + x = self.pointwise_conv2(x) + x = x.transpose(1, 2) + return x + + +class ConformerFeedForward(nn.Module): + """ + feed-forward module of Conformer model. + """ + + def __init__(self, d_model, d_ff, dropout, activation=Swish()): + super(ConformerFeedForward, self).__init__() + self.linear1 = nn.Linear(d_model, d_ff) + self.activation = activation + self.dropout = nn.Dropout(p=dropout) + self.linear2 = nn.Linear(d_ff, d_model) + + def forward(self, x): + x = self.linear1(x) + x = self.activation(x) + x = self.dropout(x) + x = self.linear2(x) + return x diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/convolution_layers.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/convolution_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..e235f7d73474b8cd6f29c08b0b98aa19bf8157de --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/convolution_layers.py @@ -0,0 +1,370 @@ +import torch +from torch import nn as nn +import torch.nn.functional as F + +conv_dic = {'1d': torch.nn.Conv1d, '2d': torch.nn.Conv2d} + +act_dic = {"hardtanh": nn.Hardtanh, "relu": nn.ReLU} + + +class ProjUpsampling(nn.Module): + def __init__(self, in_channels, filters, kernel_size, *, rate, norm_type, act_func, + dropout=0.0, + padding='same', + use_tf_pad=True, + ln_eps=1e-5, + bias=True): + super(ProjUpsampling, self).__init__() + + self.upsample_rate = rate + self.filters = filters + self.proj = ConvNormAct(in_channels=in_channels, filters=self.filters * self.upsample_rate, kernel_size=kernel_size, + stride=(1,), dilation=(1,), norm_type=None, act_func=None, + conv_type='1d', dropout=0.0, + padding=padding, use_tf_pad=use_tf_pad, ln_eps=ln_eps, gn_groups=None, bias=bias) + + assert norm_type is None or norm_type == 'ln' + self.norm = get_norm(norm_type, '1d', self.filters, ln_eps=ln_eps, gn_groups=None) + self.norm_type = norm_type + self.act = identity if act_func is None else act_dic[act_func]() + self.drop = identity if dropout == 0 else nn.Dropout(p=dropout) + + def forward(self, x, lens): + pad_mask = create_pad_mask(lens, max_len=x.size(2)) + output, lens, _ = self.proj(x, lens, pad_mask=pad_mask) + output = output.transpose(1, 2) + B, T, C = output.size() + output = output.reshape(B, T * self.upsample_rate, self.filters) + lens = lens * self.upsample_rate + output = self.norm(output) + output = self.act(output) + output = self.drop(output) + output = output.transpose(1, 2) + return output, lens + + +class ConvNormAct(nn.Module): + def __init__(self, in_channels, filters, kernel_size, stride, dilation, norm_type, act_func, + conv_type, + dropout=0.0, + padding='same', + use_tf_pad=True, + ln_eps=1e-5, + gn_groups=None, + bias=None, + residual=False): + super(ConvNormAct, self).__init__() + + if bias is None: + bias = norm_type is None + + self.conv = Conv(in_channels, filters, tuple(kernel_size), + stride=tuple(stride), + padding=padding, + dilation=tuple(dilation), + bias=bias, + conv_type=conv_type, + use_tf_pad=use_tf_pad) + self.proj_conv = None + assert conv_type in ['1d', '2d'] + self.norm = get_norm(norm_type, conv_type, filters, ln_eps=ln_eps, gn_groups=gn_groups) + self.norm_type = norm_type + self.act = identity if act_func is None else act_dic[act_func]() + self.drop = identity if dropout == 0 else nn.Dropout(p=dropout) + self.residual = residual + if self.residual: + print('residual ConvNormAct!') + assert in_channels == filters + assert stride[0] == 1 + + def forward(self, x, lens, pad_mask=None): + # x: [B, C, T] or [B, C, T, F] + + output, lens, pad_mask = self.conv(x, lens, pad_mask) + if self.norm_type == 'ln': + output = torch.transpose(output, -1, -2) + output = self.norm(output) + if self.norm_type == 'ln': + output = torch.transpose(output, -1, -2) + output = self.act(output) + output = self.drop(output) + + if self.residual: + output = output + x + + return output, lens, pad_mask + + def update_out_seq_lens(self, lens): + return self.conv.update_out_seq_lens(lens) + + +class ConvFFN(nn.Module): + def __init__(self, in_channels, filters, kernel_size, act_func, + norm_type='post_ln', + dropout=0.0, + padding='same', + use_tf_pad=True, + ln_eps=1e-5, + ): + super().__init__() + + self.conv_1 = Conv(in_channels, filters, kernel_size, + padding=padding, use_tf_pad=use_tf_pad) + self.norm_type = norm_type + self.norm = get_norm('ln', '1d', in_channels, ln_eps=ln_eps) + self.act = identity if act_func is None else act_dic[act_func]() + + if dropout == 0: + self.drop_act = identity + else: + self.drop_act = nn.Dropout(p=dropout) + + self.conv_2 = Conv(filters, in_channels, (1,), padding=padding, use_tf_pad=use_tf_pad) + + if dropout == 0: + self.drop = identity + else: + self.drop = nn.Dropout(p=dropout) + + def forward(self, inputs, lens, pad_mask=None): + # x: [B, C, T] + if self.norm_type == 'pre_ln': + output = torch.transpose(inputs, -1, -2) + output = self.norm(output) + output = torch.transpose(output, -1, -2) + else: + assert self.norm_type == 'post_ln' + output = inputs + + output, lens, pad_mask = self.conv_1(output, lens, pad_mask) + + output = self.act(output) + output = self.drop_act(output) + + output, lens, pad_mask = self.conv_2(output, lens, pad_mask) + + output = self.drop(output) + + output = output + inputs + + if self.norm_type == 'post_ln': + output = torch.transpose(output, -1, -2) + output = self.norm(output) + output = torch.transpose(output, -1, -2) + + return output, lens, pad_mask + + def update_out_seq_lens(self, lens): + return self.conv_1.update_out_seq_lens(lens) + + +class ConvBlock(nn.Module): + def __init__(self, feat_in, use_conv_mask, conv_layers=None, conv_ffn_layers=None, output_proj_dim=None, + use_tf_pad: bool=False, ln_eps: float = 1e-5): + super().__init__() + + self.use_conv_mask = use_conv_mask + + prev_out_channels = feat_in + + self.layers = nn.ModuleList() + for conv_cfg_i in conv_layers: + layer = ConvNormAct(in_channels=prev_out_channels, + conv_type='1d', + use_tf_pad=use_tf_pad, + ln_eps=ln_eps, + **conv_cfg_i) + prev_out_channels = conv_cfg_i.filters + self.layers.append(layer) + + if conv_ffn_layers: + for conv_cfg_i in conv_ffn_layers: + layer = ConvFFN(in_channels=prev_out_channels, + use_tf_pad=use_tf_pad, + ln_eps=ln_eps, + **conv_cfg_i) + prev_out_channels = conv_cfg_i.filters + self.layers.append(layer) + + if output_proj_dim: + layer = Conv(prev_out_channels, output_proj_dim, 1, use_tf_pad=use_tf_pad) + self.layers.append(layer) + + def forward(self, input, length): + # [B, C, T] + output = input + + if self.use_conv_mask: + pad_mask = create_pad_mask(length, max_len=output.size(2)) + else: + pad_mask = None + + for layer in self.layers: + output, length, pad_mask = layer(output, length, pad_mask=pad_mask) + + return output, length + + +def get_norm(norm_type, conv_type, filters, ln_eps=1e-5, gn_groups=None): + if norm_type == 'bn': + if conv_type == '2d': + norm = nn.BatchNorm2d(filters, momentum=0.01, eps=1e-3) + else: + norm = nn.BatchNorm1d(filters, momentum=0.01, eps=1e-3) + elif norm_type == 'ln': + assert conv_type != '2d' + norm = nn.LayerNorm(filters, eps=ln_eps) + elif norm_type == 'gn': + assert gn_groups is not None + norm = nn.GroupNorm(gn_groups, filters) + else: + assert norm_type is None, norm_type + norm = identity + return norm + + +# conv wrapper supports same padding, tf style padding, track length change during subsampling +class Conv(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, + stride=1, + padding='same', + dilation=1, + bias=True, + conv_type='1d', + use_tf_pad=False): + super(Conv, self).__init__() + + self.conv_type = conv_type + self.is_2d_conv = self.conv_type == '2d' + + if isinstance(kernel_size, int): + kernel_size = (kernel_size,) + if self.is_2d_conv: + kernel_size = kernel_size * 2 + + if isinstance(stride, int): + stride = (stride,) + if self.is_2d_conv: + stride = stride * 2 + + self.padding = padding + + if isinstance(dilation, int): + dilation = (dilation,) + if self.is_2d_conv: + dilation = dilation * 2 + assert dilation == (1,) or dilation == (1, 1) + + # assert use_tf_pad + self.use_tf_pad = use_tf_pad + if self.use_tf_pad: + self.pad_num, self.even_pad_num = get_tf_pad(kernel_size, stride) + + self.conv = conv_dic[self.conv_type](in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=self.get_padding_num(kernel_size, stride, dilation), + bias=bias) + self.need_pad = kernel_size[0] > 1 or (len(kernel_size) == 2 and kernel_size[1] > 1) + self.need_pad_mask = kernel_size[0] > 1 + assert stride[0] >= 1 + self.subsample_factor = stride[0] + + def forward(self, x, lens, pad_mask=None): + # x: [B, C, T] or [B, C, T, F] + if pad_mask is not None and self.need_pad_mask: + if self.is_2d_conv: + x = x.masked_fill(pad_mask.unsqueeze(1).unsqueeze(-1), 0.0) + else: + x = x.masked_fill(pad_mask.unsqueeze(1), 0.0) + + if self.use_tf_pad and self.need_pad: + x = self.pad_like_tf(x) + + output = self.conv(x) + + if self.subsample_factor > 1: + lens = self.update_out_seq_lens(lens) + pad_mask = create_pad_mask(lens, max_len=output.size(2)) + + return output, lens, pad_mask + + def get_padding_num(self, kernel_size, stride, dilation): + if self.padding == 'same': + if self.use_tf_pad: + padding_val = 0 + else: + assert not self.use_tf_pad + # assert self.conv_type == '1d' + if self.is_2d_conv: + assert kernel_size[0] == kernel_size[1] + padding_val = get_same_padding(kernel_size[0], stride[0], dilation[0]) + else: + raise ValueError("currently only 'same' padding is supported") + return padding_val + + def update_out_seq_lens(self, lens): + t = 0 # axis of time dimension + if self.padding == 'same': + if self.use_tf_pad: + lens = (lens + self.conv.stride[t] - 1) // self.conv.stride[t] + else: + # todo: verify this in pytorch + lens = (lens + 2 * self.conv.padding[t] - self.conv.dilation[t] * (self.conv.kernel_size[t] - 1) - 1) // self.conv.stride[t] + 1 + else: + assert self.padding == 'valid' and self.use_tf_pad + lens = (lens - self.conv.kernel_size[t] + self.conv.stride[t]) // self.conv.stride[t] + return lens + + def pad_like_tf(self, x): + if self.is_2d_conv: + if x.size(-1) % 2 == 0: + w_pad_num = self.even_pad_num[1] + else: + w_pad_num = self.pad_num[1] + if x.size(-2) % 2 == 0: + h_pad_num = self.even_pad_num[0] + else: + h_pad_num = self.pad_num[0] + pad_num = w_pad_num + h_pad_num + else: + if x.size(-2) % 2 == 0: + pad_num = self.even_pad_num[0] + else: + pad_num = self.pad_num[0] + + return F.pad(x, pad_num) + + +def get_same_padding(kernel_size, stride, dilation): + # todo: support 2d conv + if stride > 1 and dilation > 1: + raise ValueError("Only stride OR dilation may be greater than 1") + if dilation > 1: + return (dilation * kernel_size) // 2 - 1 + return kernel_size // 2 + + +def get_tf_pad(kernel_size, stride): + pad_config = [] + even_pad_config = [] + for i in range(len(kernel_size)): + assert kernel_size[i] % 2 == 1 + pad_num_i = kernel_size[i] // 2 + pad_config.append([pad_num_i, pad_num_i]) + if stride[i] == 2: + even_pad_config.append([pad_num_i - 1, pad_num_i]) + else: + assert stride[i] == 1 + even_pad_config.append([pad_num_i, pad_num_i]) + return pad_config, even_pad_config + + +def create_pad_mask(lens, max_len=None): + mask = torch.arange(max_len).to(lens.device) >= lens.unsqueeze(-1) + return mask + + +def identity(x): + return x diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/ctc_beam_search.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/ctc_beam_search.py new file mode 100644 index 0000000000000000000000000000000000000000..23e28472085a94eb3eecda13a2b4d54ae014b743 --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/ctc_beam_search.py @@ -0,0 +1,152 @@ +import time + +import torch + +import numpy as np + + +class CTCBeamSearchDecoder: + def __init__( + self, + beam_size: int, + blank_index: int = 0, + temperature: float = 1.0, + combine_path: bool = True + ): + + self.blank_id = blank_index + # self.vocab_size = decoder_model.vocab_size + + if beam_size < 1: + raise ValueError("Beam search size cannot be less than 1!") + + self.beam_size = beam_size + + self.beam_stepwise_ln_alpha = 0. + self.beam_word_reward_ratio = 0. + if self.beam_word_reward_ratio > 0: + assert self.beam_stepwise_ln_alpha == 0 + self.beam_combine_path = combine_path + self.beam_temperature = temperature + + self.search_algorithm = self.beam_search + + def decode(self, logits, logits_len): + assert len(logits) == len(logits_len) + all_best_hyps = [] + all_nbest_hyps = [] + decode_st = time.perf_counter() + for i, (logits_i, logits_i_len) in enumerate(zip(logits, logits_len)): + st = time.perf_counter() + best_hyps = [] + nbest_hyps = [] + for batch_idx in range(logits_i.shape[0]): + inseq = logits_i[batch_idx:batch_idx + 1] + logitlen = logits_i_len[batch_idx] + + nbest_hyps_i = self.search_algorithm(inseq, logitlen) # sorted list of hypothesis + + best_hyps.append(nbest_hyps_i[0]) + nbest_hyps.append(nbest_hyps_i) + all_best_hyps.append(best_hyps) + all_nbest_hyps.append(nbest_hyps) + + et = time.perf_counter() + print('decoding {}/{}, took {:.2f}s, all {:.1f}s, avg {:.2f}s/it'.format(i + 1, len(logits), et - st, et - decode_st, + (et - decode_st) / (i + 1)), flush=True) + + return all_best_hyps, all_nbest_hyps + + def beam_search(self, logits, logits_len): + assert logits.shape[0] == 1 + vocab_size = logits.shape[-1] + + hyps = [Hyp(score=1.0, labels=tuple(), last_label=None)] + + blank_label_id = self.blank_id + + logits = torch.from_numpy(logits) + if self.beam_temperature != 1.0: + logits = logits / self.beam_temperature + prob = logits.softmax(dim=-1) + prob = prob.cpu().numpy().astype(np.float64) + + for t in range(int(logits_len)): + prob_t = prob[:, t:t+1, np.newaxis, :] + hyps_score = np.array([hyp_i.score for hyp_i in hyps]) + hyps_score = hyps_score[:, np.newaxis] + new_hyp_score = hyps_score * prob_t + # [B, T, beam_size, V] -> [B, 1, beam_size * V] + new_hyp_score = np.reshape(new_hyp_score, [new_hyp_score.shape[0], new_hyp_score.shape[1], -1]) + + # prob_t_topk_idx = np.argsort(new_hyp_score) + prob_t_topk_idx = np.argpartition(new_hyp_score, -self.beam_size) + prob_t_topk_idx = prob_t_topk_idx[0, 0, -self.beam_size:] + + unique_hyps = {} + for path_i in prob_t_topk_idx: + hyp_num = path_i // vocab_size + hyp_i = hyps[hyp_num] + label_i = path_i % vocab_size + + if label_i == hyp_i.last_label or label_i == blank_label_id: + hyp_i_labels = hyp_i.labels + else: + hyp_i_labels = hyp_i.labels + (label_i,) + new_hyp = Hyp(score=new_hyp_score[0, 0, path_i], + labels=hyp_i_labels, + last_label=label_i) + + new_hyp_hash = (new_hyp.labels, new_hyp.last_label) + if new_hyp_hash in unique_hyps: + if self.beam_combine_path: + unique_hyps[new_hyp_hash].score += new_hyp.score + else: + if new_hyp.score > unique_hyps[new_hyp_hash].score: + unique_hyps[new_hyp_hash] = new_hyp + else: + unique_hyps[new_hyp_hash] = new_hyp + + hyps = list(unique_hyps.values()) + assert len(hyps) <= self.beam_size + + return sorted(hyps, key=lambda x: x.score, reverse=True) + + +class Hyp: + __slots__ = ('score', 'labels', 'pred_net_state', 'pred_net_output', 'last_label') + + def __init__(self, score, labels=None, last_label=None): + self.score = score + self.labels = labels + self.last_label = last_label + + def length_norm_score(self, alpha): + return length_norm(np.log(self.score), len(self.labels), alpha) + + def word_reward_score(self, reward_ratio): + label_len = len(self.labels) + # if not self.blank: + # label_len -= 1 + return np.log(self.score) + label_len * reward_ratio + + +def length_norm(log_prob, label_len, alpha): + len_norm = (label_len + 5.) / 6. + if alpha != 1: + len_norm = len_norm ** alpha + return log_prob / len_norm + + +def get_length_normed_best(hyps): + best_hyp = None + best_hyp_score = None + score_len_normed = [] + for hyp_i in hyps: + length_normed_score = np.log(hyp_i.score) / (len(hyp_i.labels) + 1e-16) + score_len_normed.append(length_normed_score) + if best_hyp_score is None or length_normed_score > best_hyp_score: + best_hyp_score = length_normed_score + best_hyp = hyp_i + assert best_hyp is not None + return best_hyp, score_len_normed diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/features.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/features.py new file mode 100644 index 0000000000000000000000000000000000000000..559981b9bd4891722af93447e9a86f1be30ff858 --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/features.py @@ -0,0 +1,455 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Copyright (c) 2018 Ryan Leary +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# This file contains code artifacts adapted from https://github.com/ryanleary/patter +import math +from inspect import signature + +import librosa +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from librosa.util import tiny +from torch.autograd import Variable +from torch_stft import STFT + +try: + import torch.cuda.amp + AMP_AVAILABLE = hasattr(torch.cuda.amp, 'autocast') +except ImportError: + AMP_AVAILABLE = False + +from nemo.collections.asr.parts.perturb import AudioAugmentor +from nemo.collections.asr.parts.segment import AudioSegment +from nemo.utils import logging + +CONSTANT = 1e-5 + + +def normalize_batch(x, seq_len, normalize_type): + if normalize_type == "per_feature": + x_mean = torch.zeros((seq_len.shape[0], x.shape[1]), dtype=x.dtype, device=x.device) + x_std = torch.zeros((seq_len.shape[0], x.shape[1]), dtype=x.dtype, device=x.device) + for i in range(x.shape[0]): + if x[i, :, : seq_len[i]].shape[1] == 1: + raise ValueError( + "normalize_batch with `per_feature` normalize_type received a tensor of length 1. This will result " + "in torch.std() returning nan" + ) + x_mean[i, :] = x[i, :, : seq_len[i]].mean(dim=1) + x_std[i, :] = x[i, :, : seq_len[i]].std(dim=1) + # make sure x_std is not zero + x_std += CONSTANT + return (x - x_mean.unsqueeze(2)) / x_std.unsqueeze(2) + elif normalize_type == "all_features": + x_mean = torch.zeros(seq_len.shape, dtype=x.dtype, device=x.device) + x_std = torch.zeros(seq_len.shape, dtype=x.dtype, device=x.device) + for i in range(x.shape[0]): + x_mean[i] = x[i, :, : seq_len[i].item()].mean() + x_std[i] = x[i, :, : seq_len[i].item()].std() + # make sure x_std is not zero + x_std += CONSTANT + return (x - x_mean.view(-1, 1, 1)) / x_std.view(-1, 1, 1) + elif "fixed_mean" in normalize_type and "fixed_std" in normalize_type: + x_mean = torch.tensor(normalize_type["fixed_mean"], device=x.device) + x_std = torch.tensor(normalize_type["fixed_std"], device=x.device) + return (x - x_mean.view(x.shape[0], x.shape[1]).unsqueeze(2)) / x_std.view(x.shape[0], x.shape[1]).unsqueeze(2) + else: + return x + + +def splice_frames(x, frame_splicing): + """ Stacks frames together across feature dim + + input is batch_size, feature_dim, num_frames + output is batch_size, feature_dim*frame_splicing, num_frames + + """ + seq = [x] + for n in range(1, frame_splicing): + seq.append(torch.cat([x[:, :, :n], x[:, :, n:]], dim=2)) + return torch.cat(seq, dim=1) + + +class WaveformFeaturizer(object): + def __init__(self, sample_rate=16000, int_values=False, augmentor=None, return_both=False): + self.augmentor = augmentor if augmentor is not None else AudioAugmentor() + self.sample_rate = sample_rate + self.int_values = int_values + self.return_both = return_both + + def max_augmentation_length(self, length): + return self.augmentor.max_augmentation_length(length) + + def process(self, file_path, offset=0, duration=0, trim=False, orig_sr=None, crop_size=None): + audio = AudioSegment.from_file( + file_path, + target_sr=self.sample_rate, + int_values=self.int_values, + offset=offset, + duration=duration, + trim=trim, + orig_sr=orig_sr, + ) + + if crop_size: + audio._samples = crop_to_max_size(audio._samples, crop_size) + + if self.return_both: + clean_audio = torch.tensor(audio._samples, dtype=torch.float) + return clean_audio, self.process_segment(audio) + else: + return self.process_segment(audio) + + def process_segment(self, audio_segment): + self.augmentor.perturb(audio_segment) + return torch.tensor(audio_segment.samples, dtype=torch.float) + + @classmethod + def from_config(cls, input_config, perturbation_configs=None): + if perturbation_configs is not None: + aa = AudioAugmentor.from_config(perturbation_configs) + else: + aa = None + + sample_rate = input_config.get("sample_rate", 16000) + int_values = input_config.get("int_values", False) + + return cls(sample_rate=sample_rate, int_values=int_values, augmentor=aa) + + +def crop_to_max_size(wav, target_size): + size = wav.shape[0] + diff = size - target_size + if diff <= 0: + return wav + + start = np.random.randint(0, diff + 1) + end = size - diff + start + return wav[start:end] + + +class FeaturizerFactory(object): + def __init__(self): + pass + + @classmethod + def from_config(cls, input_cfg, perturbation_configs=None): + return WaveformFeaturizer.from_config(input_cfg, perturbation_configs=perturbation_configs) + + +# Create helper class to patch forward func for use with AMP +class STFTPatch(STFT): + def forward(self, input_data): + return super().transform(input_data)[0] + + +# Create helper class for STFT that yields num_frames = num_samples // hop_length +class STFTExactPad(STFTPatch): + """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft""" + + def __init__(self, *params, **kw_params): + super().__init__(*params, **kw_params) + self.pad_amount = (self.filter_length - self.hop_length) // 2 + + def inverse(self, magnitude, phase): + recombine_magnitude_phase = torch.cat([magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1) + + inverse_transform = F.conv_transpose1d( + recombine_magnitude_phase, + Variable(self.inverse_basis, requires_grad=False), + stride=self.hop_length, + padding=0, + ) + + if self.window is not None: + window_sum = librosa.filters.window_sumsquare( + self.window, + magnitude.size(-1), + hop_length=self.hop_length, + win_length=self.win_length, + n_fft=self.filter_length, + dtype=np.float32, + ) + # remove modulation effects + approx_nonzero_indices = torch.from_numpy(np.where(window_sum > tiny(window_sum))[0]) + window_sum = torch.autograd.Variable(torch.from_numpy(window_sum), requires_grad=False) + inverse_transform[:, :, approx_nonzero_indices] /= window_sum[approx_nonzero_indices] + + # scale by hop ratio + inverse_transform *= self.filter_length / self.hop_length + + inverse_transform = inverse_transform[:, :, self.pad_amount :] + inverse_transform = inverse_transform[:, :, : -self.pad_amount :] + + return inverse_transform + + +class FilterbankFeatures(nn.Module): + """Featurizer that converts wavs to Mel Spectrograms. + See AudioToMelSpectrogramPreprocessor for args. + """ + + def __init__( + self, + sample_rate=16000, + n_window_size=320, + n_window_stride=160, + window="hann", + normalize="per_feature", + n_fft=None, + preemph=0.97, + nfilt=64, + lowfreq=0, + highfreq=None, + log=True, + log_zero_guard_type="add", + log_zero_guard_value=2 ** -24, + dither=CONSTANT, + dither_train_only=False, + pad_to=16, + max_duration=16.7, + frame_splicing=1, + stft_exact_pad=False, + stft_conv=False, + pad_value=0, + mag_power=2.0, + normalize_time_domain=False, + ): + super().__init__() + self.log_zero_guard_value = log_zero_guard_value + if ( + n_window_size is None + or n_window_stride is None + or not isinstance(n_window_size, int) + or not isinstance(n_window_stride, int) + or n_window_size <= 0 + or n_window_stride <= 0 + ): + raise ValueError( + f"{self} got an invalid value for either n_window_size or " + f"n_window_stride. Both must be positive ints." + ) + logging.info(f"PADDING: {pad_to}") + + self.win_length = n_window_size + self.hop_length = n_window_stride + self.n_fft = n_fft or 2 ** math.ceil(math.log2(self.win_length)) + self.stft_exact_pad = stft_exact_pad + self.stft_conv = stft_conv + + if stft_conv: + logging.info("STFT using conv") + if stft_exact_pad: + logging.info("STFT using exact pad") + self.stft = STFTExactPad(self.n_fft, self.hop_length, self.win_length, window) + else: + self.stft = STFTPatch(self.n_fft, self.hop_length, self.win_length, window) + else: + logging.info("STFT using torch") + torch_windows = { + 'hann': torch.hann_window, + 'hamming': torch.hamming_window, + 'blackman': torch.blackman_window, + 'bartlett': torch.bartlett_window, + 'none': None, + } + window_fn = torch_windows.get(window, None) + window_tensor = window_fn(self.win_length, periodic=False) if window_fn else None + self.register_buffer("window", window_tensor) + if 'return_complex' not in signature(torch.stft).parameters: + self.stft = lambda x: torch.stft( + x, + n_fft=self.n_fft, + hop_length=self.hop_length, + win_length=self.win_length, + center=False if stft_exact_pad else True, + window=self.window.to(dtype=torch.float), + ) + else: + self.stft = lambda x: torch.stft( + x, + n_fft=self.n_fft, + hop_length=self.hop_length, + win_length=self.win_length, + center=False if stft_exact_pad else True, + window=self.window.to(dtype=torch.float), + return_complex=False, + ) + + self.normalize = normalize + self.log = log + self.dither = dither + self.dither_train_only = dither_train_only + self.frame_splicing = frame_splicing + self.nfilt = nfilt + self.preemph = preemph + self.pad_to = pad_to + highfreq = highfreq or sample_rate / 2 + + filterbanks = torch.tensor( + librosa.filters.mel(sample_rate, self.n_fft, n_mels=nfilt, fmin=lowfreq, fmax=highfreq), dtype=torch.float + ).unsqueeze(0) + self.register_buffer("fb", filterbanks) + + # Calculate maximum sequence length + max_length = self.get_seq_len(torch.tensor(max_duration * sample_rate, dtype=torch.float)) + max_pad = pad_to - (max_length % pad_to) if pad_to > 0 else 0 + self.max_length = max_length + max_pad + self.pad_value = pad_value + self.mag_power = mag_power + + self.normalize_time_domain = normalize_time_domain + + # We want to avoid taking the log of zero + # There are two options: either adding or clamping to a small value + if log_zero_guard_type not in ["add", "clamp"]: + raise ValueError( + f"{self} received {log_zero_guard_type} for the " + f"log_zero_guard_type parameter. It must be either 'add' or " + f"'clamp'." + ) + # log_zero_guard_value is the the small we want to use, we support + # an actual number, or "tiny", or "eps" + self.log_zero_guard_type = log_zero_guard_type + logging.debug(f"sr: {sample_rate}") + logging.debug(f"n_fft: {self.n_fft}") + logging.debug(f"win_length: {self.win_length}") + logging.debug(f"hop_length: {self.hop_length}") + logging.debug(f"n_mels: {nfilt}") + logging.debug(f"fmin: {lowfreq}") + logging.debug(f"fmax: {highfreq}") + + def log_zero_guard_value_fn(self, x): + if isinstance(self.log_zero_guard_value, str): + if self.log_zero_guard_value == "tiny": + return torch.finfo(x.dtype).tiny + elif self.log_zero_guard_value == "eps": + return torch.finfo(x.dtype).eps + else: + raise ValueError( + f"{self} received {self.log_zero_guard_value} for the " + f"log_zero_guard_type parameter. It must be either a " + f"number, 'tiny', or 'eps'" + ) + else: + return self.log_zero_guard_value + + def get_seq_len(self, seq_len): + return torch.ceil(seq_len / self.hop_length).to(dtype=torch.long) + + @property + def filter_banks(self): + return self.fb + + @torch.no_grad() + def forward(self, x, seq_len): + if self.normalize_time_domain: + x = normalize_time_domain(x) + + seq_len = self.get_seq_len(seq_len.float()) + + if self.stft_exact_pad and not self.stft_conv: + p = (self.n_fft - self.hop_length) // 2 + x = torch.nn.functional.pad(x.unsqueeze(1), (p, p), "reflect").squeeze(1) + + # dither + if self.dither > 0 and (not self.dither_train_only or self.training): + x += self.dither * torch.randn_like(x) + + # do preemphasis + if self.preemph is not None: + x = torch.cat((x[:, 0].unsqueeze(1), x[:, 1:] - self.preemph * x[:, :-1]), dim=1) + + if AMP_AVAILABLE: + # disable autocast to get full range of stft values + with torch.cuda.amp.autocast(enabled=False): + x = self.stft(x) + else: + x = self.stft(x) + + # torch returns real, imag; so convert to magnitude + if not self.stft_conv: + x = torch.sqrt(x.pow(2).sum(-1)) + + # get power spectrum + if self.mag_power != 1.0: + x = x.pow(self.mag_power) + + # dot with filterbank energies + x = torch.matmul(self.fb.to(x.dtype), x) + + # log features if required + if self.log: + if self.log_zero_guard_type == "add": + x = torch.log(x + self.log_zero_guard_value_fn(x)) + elif self.log_zero_guard_type == "clamp": + x = torch.log(torch.clamp(x, min=self.log_zero_guard_value_fn(x))) + else: + raise ValueError("log_zero_guard_type was not understood") + + # frame splicing if required + if self.frame_splicing > 1: + x = splice_frames(x, self.frame_splicing) + + # normalize if required + if self.normalize: + x = normalize_batch(x, seq_len, normalize_type=self.normalize) + + # mask to zero any values beyond seq_len in batch, pad to multiple of + # `pad_to` (for efficiency) + pad_to = self.pad_to + if pad_to != 0: + max_len = x.size(-1) + mask = torch.arange(max_len).to(x.device) + mask = mask.expand(x.size(0), max_len) >= seq_len.unsqueeze(1) + x = x.masked_fill(mask.unsqueeze(1).type(torch.bool).to(device=x.device), self.pad_value) + del mask + else: + max_seq_len = seq_len.max() + if x.shape[-1] != max_seq_len: + x = x.narrow(dim=-1, start=0, length=max_seq_len).contiguous() + if pad_to == "max": + x = nn.functional.pad(x, (0, self.max_length - x.size(-1)), value=self.pad_value) + elif pad_to > 0: + pad_amt = x.size(-1) % pad_to + if pad_amt != 0: + x = nn.functional.pad(x, (0, pad_to - pad_amt), value=self.pad_value) + + return x, seq_len + + +def normalize_time_domain(signal): + signal = signal * (1.0 / (torch.max(torch.abs(signal), dim=1, keepdim=True)[0] + 1e-5)) + return signal diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/jasper.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/jasper.py new file mode 100644 index 0000000000000000000000000000000000000000..6beec5faf90b3dca8f27d0ea9939600e93ba58c6 --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/jasper.py @@ -0,0 +1,557 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Callable, List, Optional, Tuple + +import torch +import torch.nn as nn +from torch import Tensor + +from nemo.collections.asr.parts.activations import Swish + +jasper_activations = {"hardtanh": nn.Hardtanh, "relu": nn.ReLU, "selu": nn.SELU, "swish": Swish} + + +def init_weights(m, mode: Optional[str] = 'xavier_uniform'): + if isinstance(m, MaskedConv1d): + init_weights(m.conv, mode) + if isinstance(m, (nn.Conv1d, nn.Linear)): + if mode is not None: + if mode == 'xavier_uniform': + nn.init.xavier_uniform_(m.weight, gain=1.0) + elif mode == 'xavier_normal': + nn.init.xavier_normal_(m.weight, gain=1.0) + elif mode == 'kaiming_uniform': + nn.init.kaiming_uniform_(m.weight, nonlinearity="relu") + elif mode == 'kaiming_normal': + nn.init.kaiming_normal_(m.weight, nonlinearity="relu") + else: + raise ValueError("Unknown Initialization mode: {0}".format(mode)) + elif isinstance(m, nn.BatchNorm1d): + if m.track_running_stats: + m.running_mean.zero_() + m.running_var.fill_(1) + m.num_batches_tracked.zero_() + if m.affine: + nn.init.ones_(m.weight) + nn.init.zeros_(m.bias) + + +def compute_new_kernel_size(kernel_size, kernel_width): + new_kernel_size = max(int(kernel_size * kernel_width), 1) + # If kernel is even shape, round up to make it odd + if new_kernel_size % 2 == 0: + new_kernel_size += 1 + return new_kernel_size + + +def get_same_padding(kernel_size, stride, dilation): + if stride > 1 and dilation > 1: + raise ValueError("Only stride OR dilation may be greater than 1") + if dilation > 1: + return (dilation * kernel_size) // 2 - 1 + return kernel_size // 2 + + +class StatsPoolLayer(nn.Module): + def __init__(self, feat_in, pool_mode='xvector'): + super().__init__() + self.feat_in = 0 + if pool_mode == 'gram': + gram = True + super_vector = False + elif pool_mode == 'superVector': + gram = True + super_vector = True + else: + gram = False + super_vector = False + + if gram: + self.feat_in += feat_in ** 2 + else: + self.feat_in += 2 * feat_in + + if super_vector and gram: + self.feat_in += 2 * feat_in + + self.gram = gram + self.super = super_vector + + def forward(self, encoder_output): + + mean = encoder_output.mean(dim=-1) # Time Axis + std = encoder_output.std(dim=-1) + + pooled = torch.cat([mean, std], dim=-1) + + if self.gram: + time_len = encoder_output.shape[-1] + # encoder_output = encoder_output + cov = encoder_output.bmm(encoder_output.transpose(2, 1)) # cov matrix + cov = cov.view(cov.shape[0], -1) / time_len + + if self.gram and not self.super: + return cov + + if self.super and self.gram: + pooled = torch.cat([pooled, cov], dim=-1) + + return pooled + + +class MaskedConv1d(nn.Module): + __constants__ = ["use_conv_mask", "real_out_channels", "heads"] + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + heads=-1, + bias=False, + use_mask=True, + ): + super(MaskedConv1d, self).__init__() + + if not (heads == -1 or groups == in_channels): + raise ValueError("Only use heads for depthwise convolutions") + + self.real_out_channels = out_channels + if heads != -1: + in_channels = heads + out_channels = heads + groups = heads + + self.conv = nn.Conv1d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + ) + self.use_mask = use_mask + self.heads = heads + + def get_seq_len(self, lens): + return ( + lens + 2 * self.conv.padding[0] - self.conv.dilation[0] * (self.conv.kernel_size[0] - 1) - 1 + ) // self.conv.stride[0] + 1 + + def forward(self, x, lens): + if self.use_mask: + lens = lens.to(dtype=torch.long) + max_len = x.size(2) + mask = torch.arange(max_len).to(lens.device).expand(len(lens), max_len) >= lens.unsqueeze(1) + x = x.masked_fill(mask.unsqueeze(1).to(device=x.device), 0) + # del mask + lens = self.get_seq_len(lens) + + sh = x.shape + if self.heads != -1: + x = x.view(-1, self.heads, sh[-1]) + + out = self.conv(x) + + if self.heads != -1: + out = out.view(sh[0], self.real_out_channels, -1) + + return out, lens + + +class GroupShuffle(nn.Module): + def __init__(self, groups, channels): + super(GroupShuffle, self).__init__() + + self.groups = groups + self.channels_per_group = channels // groups + + def forward(self, x): + sh = x.shape + + x = x.view(-1, self.groups, self.channels_per_group, sh[-1]) + + x = torch.transpose(x, 1, 2).contiguous() + + x = x.view(-1, self.groups * self.channels_per_group, sh[-1]) + + return x + + +class SqueezeExcite(nn.Module): + def __init__( + self, + channels: int, + reduction_ratio: int, + context_window: int = -1, + interpolation_mode: str = 'nearest', + activation: Optional[Callable] = None, + ): + """ + Squeeze-and-Excitation sub-module. + + Args: + channels: Input number of channels. + reduction_ratio: Reduction ratio for "squeeze" layer. + context_window: Integer number of timesteps that the context + should be computed over, using stride 1 average pooling. + If value < 1, then global context is computed. + interpolation_mode: Interpolation mode of timestep dimension. + Used only if context window is > 1. + The modes available for resizing are: `nearest`, `linear` (3D-only), + `bilinear`, `area` + activation: Intermediate activation function used. Must be a + callable activation function. + """ + super(SqueezeExcite, self).__init__() + self.context_window = int(context_window) + self.interpolation_mode = interpolation_mode + + if self.context_window <= 0: + self.pool = nn.AdaptiveAvgPool1d(1) # context window = T + else: + self.pool = nn.AvgPool1d(self.context_window, stride=1) + + if activation is None: + activation = nn.ReLU(inplace=True) + + self.fc = nn.Sequential( + nn.Linear(channels, channels // reduction_ratio, bias=False), + activation, + nn.Linear(channels // reduction_ratio, channels, bias=False), + ) + + def forward(self, x): + # The use of negative indices on the transpose allow for expanded SqueezeExcite + batch, channels, timesteps = x.size()[:3] + y = self.pool(x) # [B, C, T - context_window + 1] + y = y.transpose(1, -1) # [B, T - context_window + 1, C] + y = self.fc(y) # [B, T - context_window + 1, C] + y = y.transpose(1, -1) # [B, C, T - context_window + 1] + + if self.context_window > 0: + y = torch.nn.functional.interpolate(y, size=timesteps, mode=self.interpolation_mode) + + y = torch.sigmoid(y) + + return x * y + + +class JasperBlock(nn.Module): + __constants__ = ["conv_mask", "separable", "residual_mode", "res", "mconv"] + + def __init__( + self, + inplanes, + planes, + repeat=3, + kernel_size=11, + kernel_size_factor=1, + stride=1, + dilation=1, + padding='same', + dropout=0.2, + activation=None, + residual=True, + groups=1, + separable=False, + heads=-1, + normalization="batch", + norm_groups=1, + residual_mode='add', + residual_panes=[], + conv_mask=False, + se=False, + se_reduction_ratio=16, + se_context_window=None, + se_interpolation_mode='nearest', + stride_last=False, + ): + super(JasperBlock, self).__init__() + + if padding != "same": + raise ValueError("currently only 'same' padding is supported") + + kernel_size_factor = float(kernel_size_factor) + if type(kernel_size) in (list, tuple): + kernel_size = [compute_new_kernel_size(k, kernel_size_factor) for k in kernel_size] + else: + kernel_size = compute_new_kernel_size(kernel_size, kernel_size_factor) + + padding_val = get_same_padding(kernel_size[0], stride[0], dilation[0]) + self.conv_mask = conv_mask + self.separable = separable + self.residual_mode = residual_mode + self.se = se + + inplanes_loop = inplanes + conv = nn.ModuleList() + + for _ in range(repeat - 1): + # Stride last means only the last convolution in block will have stride + if stride_last: + stride_val = [1] + else: + stride_val = stride + + conv.extend( + self._get_conv_bn_layer( + inplanes_loop, + planes, + kernel_size=kernel_size, + stride=stride_val, + dilation=dilation, + padding=padding_val, + groups=groups, + heads=heads, + separable=separable, + normalization=normalization, + norm_groups=norm_groups, + ) + ) + + conv.extend(self._get_act_dropout_layer(drop_prob=dropout, activation=activation)) + + inplanes_loop = planes + + conv.extend( + self._get_conv_bn_layer( + inplanes_loop, + planes, + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + padding=padding_val, + groups=groups, + heads=heads, + separable=separable, + normalization=normalization, + norm_groups=norm_groups, + ) + ) + + if se: + conv.append( + SqueezeExcite( + planes, + reduction_ratio=se_reduction_ratio, + context_window=se_context_window, + interpolation_mode=se_interpolation_mode, + activation=activation, + ) + ) + + self.mconv = conv + + res_panes = residual_panes.copy() + self.dense_residual = residual + + if residual: + res_list = nn.ModuleList() + + if residual_mode == 'stride_add': + stride_val = stride + else: + stride_val = [1] + + if len(residual_panes) == 0: + res_panes = [inplanes] + self.dense_residual = False + for ip in res_panes: + res = nn.ModuleList( + self._get_conv_bn_layer( + ip, + planes, + kernel_size=1, + normalization=normalization, + norm_groups=norm_groups, + stride=stride_val, + ) + ) + + res_list.append(res) + + self.res = res_list + else: + self.res = None + + self.mout = nn.Sequential(*self._get_act_dropout_layer(drop_prob=dropout, activation=activation)) + + def _get_conv( + self, + in_channels, + out_channels, + kernel_size=11, + stride=1, + dilation=1, + padding=0, + bias=False, + groups=1, + heads=-1, + separable=False, + ): + use_mask = self.conv_mask + if use_mask: + return MaskedConv1d( + in_channels, + out_channels, + kernel_size, + stride=stride, + dilation=dilation, + padding=padding, + bias=bias, + groups=groups, + heads=heads, + use_mask=use_mask, + ) + else: + return nn.Conv1d( + in_channels, + out_channels, + kernel_size, + stride=stride, + dilation=dilation, + padding=padding, + bias=bias, + groups=groups, + ) + + def _get_conv_bn_layer( + self, + in_channels, + out_channels, + kernel_size=11, + stride=1, + dilation=1, + padding=0, + bias=False, + groups=1, + heads=-1, + separable=False, + normalization="batch", + norm_groups=1, + ): + if norm_groups == -1: + norm_groups = out_channels + + if separable: + layers = [ + self._get_conv( + in_channels, + in_channels, + kernel_size, + stride=stride, + dilation=dilation, + padding=padding, + bias=bias, + groups=in_channels, + heads=heads, + ), + self._get_conv( + in_channels, + out_channels, + kernel_size=1, + stride=1, + dilation=1, + padding=0, + bias=bias, + groups=groups, + ), + ] + else: + layers = [ + self._get_conv( + in_channels, + out_channels, + kernel_size, + stride=stride, + dilation=dilation, + padding=padding, + bias=bias, + groups=groups, + ) + ] + + if normalization == "group": + layers.append(nn.GroupNorm(num_groups=norm_groups, num_channels=out_channels)) + elif normalization == "instance": + layers.append(nn.GroupNorm(num_groups=out_channels, num_channels=out_channels)) + elif normalization == "layer": + layers.append(nn.GroupNorm(num_groups=1, num_channels=out_channels)) + elif normalization == "batch": + layers.append(nn.BatchNorm1d(out_channels, eps=1e-3, momentum=0.1)) + else: + raise ValueError( + f"Normalization method ({normalization}) does not match" f" one of [batch, layer, group, instance]." + ) + + if groups > 1: + layers.append(GroupShuffle(groups, out_channels)) + return layers + + def _get_act_dropout_layer(self, drop_prob=0.2, activation=None): + if activation is None: + activation = nn.Hardtanh(min_val=0.0, max_val=20.0) + layers = [activation, nn.Dropout(p=drop_prob)] + return layers + + def forward(self, input_: Tuple[List[Tensor], Optional[Tensor]]): + # type: (Tuple[List[Tensor], Optional[Tensor]]) -> Tuple[List[Tensor], Optional[Tensor]] # nopep8 + lens_orig = None + xs = input_[0] + if len(input_) == 2: + xs, lens_orig = input_ + + # compute forward convolutions + out = xs[-1] + + lens = lens_orig + for i, l in enumerate(self.mconv): + # if we're doing masked convolutions, we need to pass in and + # possibly update the sequence lengths + # if (i % 4) == 0 and self.conv_mask: + if isinstance(l, MaskedConv1d): + out, lens = l(out, lens) + else: + out = l(out) + + # compute the residuals + if self.res is not None: + for i, layer in enumerate(self.res): + res_out = xs[i] + for j, res_layer in enumerate(layer): + if isinstance(res_layer, MaskedConv1d): + res_out, _ = res_layer(res_out, lens_orig) + else: + res_out = res_layer(res_out) + + if self.residual_mode == 'add' or self.residual_mode == 'stride_add': + out = out + res_out + else: + out = torch.max(out, res_out) + + # compute the output + out = self.mout(out) + if self.res is not None and self.dense_residual: + return xs + [out], lens + + return [out], lens diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/layer_norm.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/layer_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..926fb9ea1ff6c9d3c38e31baa0456fba03f69cba --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/layer_norm.py @@ -0,0 +1,50 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +try: + from apex.normalization import FusedLayerNorm as _FusedLayerNorm + + has_fused_layernorm = True + + class FusedLayerNorm(_FusedLayerNorm): + @torch.jit.unused + def forward(self, x): + if not x.is_cuda: + return super().forward(x) + else: + with torch.cuda.device(x.device): + return super().forward(x) + + +except ImportError: + has_fused_layernorm = False + + +def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True, export=False): + if torch.jit.is_scripting(): + export = True + # if not export and torch.cuda.is_available() and has_fused_layernorm: + # return FusedLayerNorm(normalized_shape, eps, elementwise_affine) + return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine) + + +class Fp32LayerNorm(nn.LayerNorm): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, input): + output = F.layer_norm( + input.float(), + self.normalized_shape, + self.weight.float() if self.weight is not None else None, + self.bias.float() if self.bias is not None else None, + self.eps, + ) + return output.type_as(input) diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/manifest.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/manifest.py new file mode 100644 index 0000000000000000000000000000000000000000..21e32bad43f836f9a412a66d010b48d7f88ef1b3 --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/manifest.py @@ -0,0 +1,144 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from os.path import expanduser +from typing import Any, Callable, Dict, Iterator, List, Optional, Union + +import pandas + + +class ManifestBase: + def __init__(self, *args, **kwargs): + raise ValueError( + "This class is deprecated, look at https://github.com/NVIDIA/NeMo/pull/284 for correct behaviour." + ) + + +class ManifestEN: + def __init__(self, *args, **kwargs): + raise ValueError( + "This class is deprecated, look at https://github.com/NVIDIA/NeMo/pull/284 for correct behaviour." + ) + + +def item_iter( + manifests_files: Union[str, List[str]], parse_func: Callable[[str, Optional[str]], Dict[str, Any]] = None +) -> Iterator[Dict[str, Any]]: + """Iterate through json lines of provided manifests. + + NeMo ASR pipelines often assume certain manifest files structure. In + particular, each manifest file should consist of line-per-sample files with + each line being correct json dict. Each such json dict should have a field + for audio file string, a field for duration float and a field for text + string. Offset also could be additional field and is set to None by + default. + + Args: + manifests_files: Either single string file or list of such - + manifests to yield items from. + + parse_func: A callable function which accepts as input a single line + of a manifest and optionally the manifest file itself, + and parses it, returning a dictionary mapping from str -> Any. + + Yields: + Parsed key to value item dicts. + + Raises: + ValueError: If met invalid json line structure. + """ + + if isinstance(manifests_files, str): + manifests_files = [manifests_files] + + if parse_func is None: + parse_func = __parse_item + + k = -1 + for manifest_file in manifests_files: + if not manifest_file.endswith('.csv'): + with open(expanduser(manifest_file), 'r') as f: + for line in f: + k += 1 + item = parse_func(line, manifest_file) + item['id'] = k + + yield item + else: + df = pandas.read_csv(manifest_file, encoding='utf-8') + for row in df.itertuples(index=False): + k += 1 + item = __parse_csv_item(row, manifest_file) + item['id'] = k + + yield item + + +def __parse_item(line: str, manifest_file: str) -> Dict[str, Any]: + item = json.loads(line) + + # Audio file + if 'audio_filename' in item: + item['audio_file'] = item.pop('audio_filename') + elif 'audio_filepath' in item: + item['audio_file'] = item.pop('audio_filepath') + elif 'audio' in item: + item['audio_file'] = item.pop('audio') + else: + raise ValueError( + f"Manifest file {manifest_file} has invalid json line structure: {line} without proper audio file key." + ) + item['audio_file'] = expanduser(item['audio_file']) + + # Duration. + if 'duration' not in item: + raise ValueError( + f"Manifest file {manifest_file} has invalid json line structure: {line} without proper duration key." + ) + + # Text. + if 'text' in item: + pass + elif 'text_filepath' in item: + with open(item.pop('text_filepath'), 'r') as f: + item['text'] = f.read().replace('\n', '') + elif 'normalized_text' in item: + item['text'] = item['normalized_text'] + else: + item['text'] = None + + item = dict( + audio_file=item['audio_file'], + duration=item['duration'], + text=item['text'], + offset=item.get('offset', None), + speaker=item.get('speaker', None), + orig_sr=item.get('orig_sample_rate', None), + ) + + return item + + +def __parse_csv_item(row, manifest_file: str) -> Dict[str, Any]: + item = {'audio_file': row.wav_filename, + 'duration': (int(row.wav_filesize) - 44) / (16000 * 2), + 'offset': None, + 'speaker': None, + 'orig_sr': None} + + if hasattr(row, 'transcript'): + item['text'] = row.transcript + + return item diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/mixins.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/mixins.py new file mode 100644 index 0000000000000000000000000000000000000000..7111c17db1910b32b8370fd3f806fff9f2fe4fb1 --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/mixins.py @@ -0,0 +1,111 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from abc import ABC + +from omegaconf import DictConfig, OmegaConf + +from nemo.collections.common import tokenizers +from nemo.utils import logging + + +class ASRBPEMixin(ABC): + """ ASR BPE Mixin class that sets up a Tokenizer via a config + + This mixin class adds the method `_setup_tokenizer(...)`, which can be used by ASR models + which depend on subword tokenization. + + The setup_tokenizer method adds the following parameters to the class - + - tokenizer_cfg: The resolved config supplied to the tokenizer (with `dir` and `type` arguments). + - tokenizer_dir: The directory path to the tokenizer vocabulary + additional metadata. + - tokenizer_type: The type of the tokenizer. Currently supports `bpe` and `wpe`. + - vocab_path: Resolved path to the vocabulary text file. + + In addition to these variables, the method will also instantiate and preserve a tokenizer + (subclass of TokenizerSpec) if successful, and assign it to self.tokenizer. + """ + + def _setup_tokenizer(self, tokenizer_cfg: DictConfig, register_artifact=True): + # Prevent tokenizer parallelism (unless user has explicitly set it) + if 'TOKENIZERS_PARALLELISM' not in os.environ: + os.environ['TOKENIZERS_PARALLELISM'] = 'false' + + self.tokenizer_cfg = OmegaConf.to_container(tokenizer_cfg, resolve=True) # type: dict + self.tokenizer_dir = self.tokenizer_cfg.pop('dir') # Remove tokenizer directory + self.tokenizer_file = self.tokenizer_cfg.pop('file') + self.tokenizer_type = self.tokenizer_cfg.pop('type').lower() # Remove tokenizer_type + + if self.tokenizer_type not in ['bpe', 'wpe']: + raise ValueError( + "`tokenizer.type` must be either `bpe` for SentencePiece tokenizer or " + "`wpe` for BERT based tokenizer" + ) + + if self.tokenizer_type == 'bpe': + # This is a BPE Tokenizer + tokenizer_file = self.tokenizer_file + '.model' if self.tokenizer_file else 'tokenizer.model' + model_path = os.path.join(self.tokenizer_dir, tokenizer_file) + if register_artifact: + model_path = self.register_artifact('tokenizer.model_path', model_path) + self.model_path = model_path + + if 'special_tokens' in self.tokenizer_cfg: + special_tokens = self.tokenizer_cfg['special_tokens'] + else: + special_tokens = None + + # Update special tokens + self.tokenizer = tokenizers.SentencePieceTokenizer(model_path=model_path, special_tokens=special_tokens) + + vocab_file = self.tokenizer_file + '.vocab' if self.tokenizer_file else 'vocab.txt' + vocab_path = os.path.join(self.tokenizer_dir, vocab_file) + if register_artifact: + vocab_path = self.register_artifact('tokenizer.vocab_path', vocab_path) + self.vocab_path = vocab_path + + if self.tokenizer_cfg.get('prepend_unk_to_vocab', True): + vocabulary = {0: ''} + else: + vocabulary = {} + with open(vocab_path, encoding='utf-8') as f: + for i, piece in enumerate(f): + piece = piece.replace('\n', '') + vocabulary[i + 1] = piece + + # wrapper method to get vocabulary conveniently + def get_vocab(): + return vocabulary + + # attach utility values to the tokenizer wrapper + self.tokenizer.tokenizer.vocab_size = len(vocabulary) + self.tokenizer.tokenizer.get_vocab = get_vocab + self.tokenizer.tokenizer.all_special_tokens = self.tokenizer.special_token_to_id + + else: + # This is a WPE Tokenizer + vocab_path = os.path.join(self.tokenizer_dir, 'vocab.txt') + if register_artifact: + self.tokenizer_dir = self.register_artifact('tokenizer.vocab_path', vocab_path) + self.vocab_path = self.tokenizer_dir + + self.tokenizer = tokenizers.AutoTokenizer( + pretrained_model_name='bert-base-cased', vocab_file=self.tokenizer_dir, **self.tokenizer_cfg + ) + + logging.info( + "Tokenizer {} initialized with {} tokens".format( + self.tokenizer.__class__.__name__, self.tokenizer.vocab_size + ) + ) diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/multi_head_attention.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/multi_head_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..bcaa246175bb6d063f42fb39c4bdc874c46d0e76 --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/multi_head_attention.py @@ -0,0 +1,305 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Copyright 2017 Johns Hopkins University (Shinji Watanabe) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Part of this code is adopted from https://github.com/espnet/espnet +""" + +import math + +import numpy as np +import torch +import torch.nn as nn + +__all__ = [ + 'RelPositionMultiHeadAttention', + 'RelPositionalEncoding', + 'PositionalEncoding', +] + + +class MultiHeadAttention(nn.Module): + """Multi-Head Attention layer. + Args: + n_head (int): number of heads + n_feat (int): size of the features + dropout_rate (float): dropout rate + """ + + def __init__(self, n_head, n_feat, dropout_rate): + """Construct an MultiHeadedAttention object.""" + super(MultiHeadAttention, self).__init__() + assert n_feat % n_head == 0 + # We assume d_v always equals d_k + self.d_k = n_feat // n_head + self.h = n_head + self.linear_q = nn.Linear(n_feat, n_feat) + self.linear_k = nn.Linear(n_feat, n_feat) + self.linear_v = nn.Linear(n_feat, n_feat) + self.linear_out = nn.Linear(n_feat, n_feat) + self.attn = None + self.dropout = nn.Dropout(p=dropout_rate) + + def forward_qkv(self, query, key, value): + """Transforms query, key and value. + Args: + query (torch.Tensor): (batch, time1, size) + key (torch.Tensor): (batch, time2, size) + value (torch.Tensor): (batch, time2, size) + returns: + q (torch.Tensor): (batch, head, time1, size) + k (torch.Tensor): (batch, head, time2, size) + v (torch.Tensor): (batch, head, time2, size) + """ + n_batch = query.size(0) + q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k) + k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k) + v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k) + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + return q, k, v + + def forward_attention(self, value, scores, mask): + """Compute attention context vector. + Args: + value (torch.Tensor): (batch, time2, size) + scores(torch.Tensor): (batch, time1, time2) + mask(torch.Tensor): (batch, time1, time2) + returns: + value (torch.Tensor): transformed `value` (batch, time2, d_model) weighted by the attention scores + """ + n_batch = value.size(0) + if mask is not None: + mask = mask.unsqueeze(1) # (batch, 1, time1, time2) + if scores.dtype == torch.float16: + dtype = np.float16 + else: + dtype = np.float32 + min_value = np.finfo(dtype).min + scores = scores.masked_fill(mask, min_value) + self.attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0) # (batch, head, time1, time2) + else: + self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2) + + p_attn = self.dropout(self.attn) + x = torch.matmul(p_attn, value) # (batch, head, time1, d_k) + x = x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k) # (batch, time1, d_model) + + return self.linear_out(x) # (batch, time1, d_model) + + def forward(self, query, key, value, mask, pos_emb=None): + """Compute 'Scaled Dot Product Attention'. + Args: + query (torch.Tensor): (batch, time1, size) + key (torch.Tensor): (batch, time2, size) + value(torch.Tensor): (batch, time2, size) + mask (torch.Tensor): (batch, time1, time2) + returns: + output (torch.Tensor): transformed `value` (batch, time1, d_model) weighted by the query dot key attention + """ + q, k, v = self.forward_qkv(query, key, value) + scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k) + return self.forward_attention(v, scores, mask) + + +class RelPositionMultiHeadAttention(MultiHeadAttention): + """Multi-Head Attention layer with relative position encoding. + Paper: https://arxiv.org/abs/1901.02860 + Args: + n_head (int): number of heads + n_feat (int): size of the features + dropout_rate (float): dropout rate + """ + + def __init__(self, n_head, n_feat, dropout_rate, pos_bias_u, pos_bias_v): + """Construct an RelPositionMultiHeadedAttention object.""" + super().__init__(n_head, n_feat, dropout_rate) + # linear transformation for positional encoding + self.linear_pos = nn.Linear(n_feat, n_feat, bias=False) + # these two learnable bias are used in matrix c and matrix d + # as described in https://arxiv.org/abs/1901.02860 Section 3.3 + if pos_bias_u is None or pos_bias_v is None: + self.pos_bias_u = nn.Parameter(torch.FloatTensor(self.h, self.d_k)) + self.pos_bias_v = nn.Parameter(torch.FloatTensor(self.h, self.d_k)) + # nn.init.normal_(self.pos_bias_u, 0.0, 0.02) + # nn.init.normal_(self.pos_bias_v, 0.0, 0.02) + nn.init.zeros_(self.pos_bias_u) + nn.init.zeros_(self.pos_bias_v) + + else: + self.pos_bias_u = pos_bias_u + self.pos_bias_v = pos_bias_v + + def rel_shift(self, x): + """Compute relative positional encoding. + Args: + x (torch.Tensor): (batch, nheads, time, 2*time-1) + """ + qlen = x.size(2) + pos_len = x.size(-1) + x = x.view(x.size(0), x.size(1), -1) + x = torch.nn.functional.pad(x, pad=(0, qlen)) + x = x.view(x.size(0), x.size(1), qlen, pos_len + 1) + return x[:, :, :, 0:qlen].flip(dims=[-1]) + + def forward(self, query, key, value, mask, pos_emb): + """Compute 'Scaled Dot Product Attention' with rel. positional encoding. + Args: + query (torch.Tensor): (batch, time1, size) + key (torch.Tensor): (batch, time2, size) + value(torch.Tensor): (batch, time2, size) + mask (torch.Tensor): (batch, time1, time2) + pos_emb (torch.Tensor) : (batch, time1, size) + Returns: + output (torch.Tensor): transformed `value` (batch, time1, d_model) weighted by the query dot key attention + """ + q, k, v = self.forward_qkv(query, key, value) + q = q.transpose(1, 2) # (batch, time1, head, d_k) + + n_batch_pos = pos_emb.size(0) + p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k) + p = p.transpose(1, 2) # (batch, head, time1, d_k) + + # (batch, head, time1, d_k) + q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2) + # (batch, head, time1, d_k) + q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2) + + # compute attention score + # first compute matrix a and matrix c + # as described in https://arxiv.org/abs/1901.02860 Section 3.3 + # (batch, head, time1, time2) + matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1)) + + # compute matrix b and matrix d + # (batch, head, time1, time2) + matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1)) + matrix_bd = self.rel_shift(matrix_bd) + + scores = (matrix_ac + matrix_bd) / math.sqrt(self.d_k) # (batch, head, time1, time2) + + return self.forward_attention(v, scores, mask) + + +class PositionalEncoding(torch.nn.Module): + """Positional encoding. + Args: + d_model (int): embedding dim + dropout_rate (float): dropout rate + max_len (int): maximum input length + reverse (int): whether to reverse the input position + """ + + def __init__(self, d_model, dropout_rate, max_len=5000, reverse=False, xscale=None): + """Construct an PositionalEncoding object.""" + super(PositionalEncoding, self).__init__() + self.d_model = d_model + self.reverse = reverse + self.xscale = xscale + self.dropout = torch.nn.Dropout(p=dropout_rate) + self.pe = None + self.extend_pe(torch.tensor(0.0).expand(1, max_len)) + + def extend_pe(self, x): + """Reset the positional encodings.""" + if self.reverse: + needed_size = 2 * x.size(1) - 1 + else: + needed_size = x.size(1) + if self.pe is not None: + if self.pe.size(1) >= needed_size: + if self.pe.dtype != x.dtype or self.pe.device != x.device: + self.pe = self.pe.to(dtype=x.dtype, device=x.device) + return + pe = torch.zeros(needed_size, self.d_model) + if self.reverse: + position = torch.arange(-(x.size(1) - 1), x.size(1), 1.0, dtype=torch.float32).unsqueeze(1) + else: + position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, self.d_model, 2, dtype=torch.float32) * -(math.log(10000.0) / self.d_model) + ) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0) + self.pe = pe.to(device=x.device, dtype=x.dtype) + + def forward(self, x: torch.Tensor): + """Add positional encoding. + Args: + x (torch.Tensor): Input. Its shape is (batch, time, ...) + Returns: + Encoded Output (torch.Tensor): Its shape is (batch, time, ...) + """ + self.extend_pe(x) + if self.xscale: + x = x * self.xscale + x = x + self.pe[:, : x.size(1)] + return self.dropout(x), None + + +class RelPositionalEncoding(PositionalEncoding): + """Relative positional encoding module. + See : Appendix B in https://arxiv.org/abs/1901.02860 + Args: + d_model (int): embedding dim + dropout_rate (float): dropout rate + max_len (int): maximum input length + """ + + def __init__(self, d_model, dropout_rate, max_len=5000, xscale=None, dropout_emb_rate=0.0): + super().__init__(d_model, dropout_rate, max_len, reverse=True, xscale=xscale) + + if dropout_emb_rate > 0: + self.dropout_emb = nn.Dropout(dropout_emb_rate) + else: + self.dropout_emb = None + + self.max_len = max_len + + def forward(self, x): + """Compute positional encoding. + Args: + x (torch.Tensor): Input. Its shape is (batch, time, ...) + Returns: + x (torch.Tensor): Its shape is (batch, time, ...) + pos_emb (torch.Tensor): Its shape is (1, time, ...) + """ + self.extend_pe(x) + if self.xscale: + x = x * self.xscale + + start_pos = (self.pe.size(1) + 1) // 2 - x.size(1) + pos_emb = self.pe[:, start_pos:-start_pos] + if self.dropout_emb: + pos_emb = self.dropout_emb(pos_emb) + return self.dropout(x), pos_emb diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/multihead_attention.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/multihead_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..e43a359affeff3c0b7b44020a5ec7fea0765ba4e --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/multihead_attention.py @@ -0,0 +1,465 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math +from typing import Dict, Optional, Tuple + +import torch +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn import Parameter + + +class MultiheadAttention(nn.Module): + """Multi-headed attention. + + See "Attention Is All You Need" for more details. + """ + + def __init__( + self, + embed_dim, + num_heads, + kdim=None, + vdim=None, + dropout=0.0, + bias=True, + add_bias_kv=False, + add_zero_attn=False, + self_attention=False, + encoder_decoder_attention=False, + q_noise=0.0, + qn_block_size=8, + ): + super().__init__() + self.embed_dim = embed_dim + self.kdim = kdim if kdim is not None else embed_dim + self.vdim = vdim if vdim is not None else embed_dim + self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim + + self.num_heads = num_heads + self.dropout_module = torch.nn.Dropout( + dropout + ) + + self.head_dim = embed_dim // num_heads + assert ( + self.head_dim * num_heads == self.embed_dim + ), "embed_dim must be divisible by num_heads" + self.scaling = self.head_dim ** -0.5 + + self.self_attention = self_attention + self.encoder_decoder_attention = encoder_decoder_attention + + assert not self.self_attention or self.qkv_same_dim, ( + "Self-attention requires query, key and " "value to be of the same size" + ) + + self.k_proj = nn.Linear(self.kdim, embed_dim, bias=bias) + + self.v_proj = nn.Linear(self.vdim, embed_dim, bias=bias) + + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + if add_bias_kv: + self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim)) + self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim)) + else: + self.bias_k = self.bias_v = None + + self.add_zero_attn = add_zero_attn + + self.reset_parameters() + + self.onnx_trace = False + + def prepare_for_onnx_export_(self): + self.onnx_trace = True + + def reset_parameters(self): + if self.qkv_same_dim: + # Empirically observed the convergence to be much better with + # the scaled initialization + nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2)) + nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2)) + nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2)) + else: + nn.init.xavier_uniform_(self.k_proj.weight) + nn.init.xavier_uniform_(self.v_proj.weight) + nn.init.xavier_uniform_(self.q_proj.weight) + + nn.init.xavier_uniform_(self.out_proj.weight) + if self.out_proj.bias is not None: + nn.init.constant_(self.out_proj.bias, 0.0) + if self.bias_k is not None: + nn.init.xavier_normal_(self.bias_k) + if self.bias_v is not None: + nn.init.xavier_normal_(self.bias_v) + + def forward( + self, + query, + key: Optional[Tensor], + value: Optional[Tensor], + key_padding_mask: Optional[Tensor] = None, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + need_weights: bool = True, + static_kv: bool = False, + attn_mask: Optional[Tensor] = None, + before_softmax: bool = False, + need_head_weights: bool = False, + ) -> Tuple[Tensor, Optional[Tensor]]: + """Input shape: Time x Batch x Channel + + Args: + key_padding_mask (ByteTensor, optional): mask to exclude + keys that are pads, of shape `(batch, src_len)`, where + padding elements are indicated by 1s. + need_weights (bool, optional): return the attention weights, + averaged over heads (default: False). + attn_mask (ByteTensor, optional): typically used to + implement causal attention, where the mask prevents the + attention from looking forward in time (default: None). + before_softmax (bool, optional): return the raw attention + weights and values before the attention softmax. + need_head_weights (bool, optional): return the attention + weights for each head. Implies *need_weights*. Default: + return the average attention weights over all heads. + """ + if need_head_weights: + need_weights = True + + is_tpu = query.device.type == "xla" + + tgt_len, bsz, embed_dim = query.size() + src_len = tgt_len + assert embed_dim == self.embed_dim + assert list(query.size()) == [tgt_len, bsz, embed_dim] + if key is not None: + src_len, key_bsz, _ = key.size() + if not torch.jit.is_scripting(): + assert key_bsz == bsz + assert value is not None + assert src_len, bsz == value.shape[:2] + + if ( + not self.onnx_trace + and not is_tpu # don't use PyTorch version on TPUs + and incremental_state is None + and not static_kv + # A workaround for quantization to work. Otherwise JIT compilation + # treats bias in linear module as method. + and not torch.jit.is_scripting() + ): + assert key is not None and value is not None + return F.multi_head_attention_forward( + query, + key, + value, + self.embed_dim, + self.num_heads, + torch.empty([0]), + torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)), + self.bias_k, + self.bias_v, + self.add_zero_attn, + self.dropout_module.p, + self.out_proj.weight, + self.out_proj.bias, + self.training, + key_padding_mask, + need_weights, + attn_mask, + use_separate_proj_weight=True, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + ) + + if incremental_state is not None: + saved_state = self._get_input_buffer(incremental_state) + if saved_state is not None and "prev_key" in saved_state: + # previous time steps are cached - no need to recompute + # key and value if they are static + if static_kv: + assert self.encoder_decoder_attention and not self.self_attention + key = value = None + else: + saved_state = None + + if self.self_attention: + q = self.q_proj(query) + k = self.k_proj(query) + v = self.v_proj(query) + elif self.encoder_decoder_attention: + # encoder-decoder attention + q = self.q_proj(query) + if key is None: + assert value is None + k = v = None + else: + k = self.k_proj(key) + v = self.v_proj(key) + + else: + assert key is not None and value is not None + q = self.q_proj(query) + k = self.k_proj(key) + v = self.v_proj(value) + q *= self.scaling + + if self.bias_k is not None: + assert self.bias_v is not None + k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)]) + v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)]) + if attn_mask is not None: + attn_mask = torch.cat( + [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1 + ) + if key_padding_mask is not None: + key_padding_mask = torch.cat( + [ + key_padding_mask, + key_padding_mask.new_zeros(key_padding_mask.size(0), 1), + ], + dim=1, + ) + + q = ( + q.contiguous() + .view(tgt_len, bsz * self.num_heads, self.head_dim) + .transpose(0, 1) + ) + if k is not None: + k = ( + k.contiguous() + .view(-1, bsz * self.num_heads, self.head_dim) + .transpose(0, 1) + ) + if v is not None: + v = ( + v.contiguous() + .view(-1, bsz * self.num_heads, self.head_dim) + .transpose(0, 1) + ) + + if saved_state is not None: + # saved states are stored with shape (bsz, num_heads, seq_len, head_dim) + if "prev_key" in saved_state: + _prev_key = saved_state["prev_key"] + assert _prev_key is not None + prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim) + if static_kv: + k = prev_key + else: + assert k is not None + k = torch.cat([prev_key, k], dim=1) + src_len = k.size(1) + if "prev_value" in saved_state: + _prev_value = saved_state["prev_value"] + assert _prev_value is not None + prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim) + if static_kv: + v = prev_value + else: + assert v is not None + v = torch.cat([prev_value, v], dim=1) + prev_key_padding_mask: Optional[Tensor] = None + if "prev_key_padding_mask" in saved_state: + prev_key_padding_mask = saved_state["prev_key_padding_mask"] + assert k is not None and v is not None + key_padding_mask = MultiheadAttention._append_prev_key_padding_mask( + key_padding_mask=key_padding_mask, + prev_key_padding_mask=prev_key_padding_mask, + batch_size=bsz, + src_len=k.size(1), + static_kv=static_kv, + ) + + saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim) + saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim) + saved_state["prev_key_padding_mask"] = key_padding_mask + # In this branch incremental_state is never None + assert incremental_state is not None + incremental_state = self._set_input_buffer(incremental_state, saved_state) + assert k is not None + assert k.size(1) == src_len + + # This is part of a workaround to get around fork/join parallelism + # not supporting Optional types. + if key_padding_mask is not None and key_padding_mask.dim() == 0: + key_padding_mask = None + + if key_padding_mask is not None: + assert key_padding_mask.size(0) == bsz + assert key_padding_mask.size(1) == src_len + + if self.add_zero_attn: + assert v is not None + src_len += 1 + k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1) + v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1) + if attn_mask is not None: + attn_mask = torch.cat( + [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1 + ) + if key_padding_mask is not None: + key_padding_mask = torch.cat( + [ + key_padding_mask, + torch.zeros(key_padding_mask.size(0), 1).type_as( + key_padding_mask + ), + ], + dim=1, + ) + + attn_weights = torch.bmm(q, k.transpose(1, 2)) + attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz) + + assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len] + + if attn_mask is not None: + attn_mask = attn_mask.unsqueeze(0) + if self.onnx_trace: + attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1) + attn_weights += attn_mask + + if key_padding_mask is not None: + # don't attend to padding symbols + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + if not is_tpu: + attn_weights = attn_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), + float("-inf"), + ) + else: + attn_weights = attn_weights.transpose(0, 2) + attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf")) + attn_weights = attn_weights.transpose(0, 2) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if before_softmax: + return attn_weights, v + + attn_weights_float = softmax( + attn_weights, dim=-1, onnx_trace=self.onnx_trace + ) + attn_weights = attn_weights_float.type_as(attn_weights) + attn_probs = self.dropout_module(attn_weights) + + assert v is not None + attn = torch.bmm(attn_probs, v) + assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim] + if self.onnx_trace and attn.size(1) == 1: + # when ONNX tracing a single decoder step (sequence length == 1) + # the transpose is a no-op copy before view, thus unnecessary + attn = attn.contiguous().view(tgt_len, bsz, embed_dim) + else: + attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) + attn = self.out_proj(attn) + attn_weights: Optional[Tensor] = None + if need_weights: + attn_weights = attn_weights_float.view( + bsz, self.num_heads, tgt_len, src_len + ).transpose(1, 0) + if not need_head_weights: + # average attention weights over heads + attn_weights = attn_weights.mean(dim=0) + + return attn, attn_weights + + @staticmethod + def _append_prev_key_padding_mask( + key_padding_mask: Optional[Tensor], + prev_key_padding_mask: Optional[Tensor], + batch_size: int, + src_len: int, + static_kv: bool, + ) -> Optional[Tensor]: + # saved key padding masks have shape (bsz, seq_len) + if prev_key_padding_mask is not None and static_kv: + new_key_padding_mask = prev_key_padding_mask + elif prev_key_padding_mask is not None and key_padding_mask is not None: + new_key_padding_mask = torch.cat( + [prev_key_padding_mask.float(), key_padding_mask.float()], dim=1 + ) + # During incremental decoding, as the padding token enters and + # leaves the frame, there will be a time when prev or current + # is None + elif prev_key_padding_mask is not None: + if src_len > prev_key_padding_mask.size(1): + filler = torch.zeros( + (batch_size, src_len - prev_key_padding_mask.size(1)), + device=prev_key_padding_mask.device, + ) + new_key_padding_mask = torch.cat( + [prev_key_padding_mask.float(), filler.float()], dim=1 + ) + else: + new_key_padding_mask = prev_key_padding_mask.float() + elif key_padding_mask is not None: + if src_len > key_padding_mask.size(1): + filler = torch.zeros( + (batch_size, src_len - key_padding_mask.size(1)), + device=key_padding_mask.device, + ) + new_key_padding_mask = torch.cat( + [filler.float(), key_padding_mask.float()], dim=1 + ) + else: + new_key_padding_mask = key_padding_mask.float() + else: + new_key_padding_mask = prev_key_padding_mask + return new_key_padding_mask + + @torch.jit.export + def reorder_incremental_state( + self, + incremental_state: Dict[str, Dict[str, Optional[Tensor]]], + new_order: Tensor, + ): + """Reorder buffered internal state (for incremental generation).""" + input_buffer = self._get_input_buffer(incremental_state) + if input_buffer is not None: + for k in input_buffer.keys(): + input_buffer_k = input_buffer[k] + if input_buffer_k is not None: + if self.encoder_decoder_attention and input_buffer_k.size( + 0 + ) == new_order.size(0): + break + input_buffer[k] = input_buffer_k.index_select(0, new_order) + incremental_state = self._set_input_buffer(incremental_state, input_buffer) + return incremental_state + + def _get_input_buffer( + self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] + ) -> Dict[str, Optional[Tensor]]: + result = self.get_incremental_state(incremental_state, "attn_state") + if result is not None: + return result + else: + empty_result: Dict[str, Optional[Tensor]] = {} + return empty_result + + def _set_input_buffer( + self, + incremental_state: Dict[str, Dict[str, Optional[Tensor]]], + buffer: Dict[str, Optional[Tensor]], + ): + return self.set_incremental_state(incremental_state, "attn_state", buffer) + + def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int): + return attn_weights + + +def softmax(x, dim: int, onnx_trace: bool = False): + if onnx_trace: + return F.softmax(x.float(), dim=dim) + else: + return F.softmax(x, dim=dim, dtype=torch.float32) diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/numba/__init__.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/numba/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..77a23cf78c022fac041754632f0dfee3b7defa82 --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/numba/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nemo.collections.asr.parts.numba.rnnt_loss.rnnt_pytorch import RNNTLossNumba diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/numba/rnnt_loss/__init__.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/numba/rnnt_loss/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..77c5b60fdfb38bca6a5de1888885abfbc639ee25 --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/numba/rnnt_loss/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nemo.collections.asr.parts.numba.rnnt_loss.rnnt import rnnt_loss_cpu, rnnt_loss_gpu +from nemo.collections.asr.parts.numba.rnnt_loss.rnnt_pytorch import RNNTLossNumba diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/numba/rnnt_loss/rnnt.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/numba/rnnt_loss/rnnt.py new file mode 100644 index 0000000000000000000000000000000000000000..4acdb680cd13649c8210702ed066218cd01d9d50 --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/numba/rnnt_loss/rnnt.py @@ -0,0 +1,233 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Copyright 2018-2019, Mingkun Huang +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import multiprocessing + +import torch +from numba import cuda + +from nemo.collections.asr.parts.numba.rnnt_loss.utils import global_constants, rnnt_helper +from nemo.collections.asr.parts.numba.rnnt_loss.utils.cpu_utils import cpu_rnnt +from nemo.collections.asr.parts.numba.rnnt_loss.utils.cuda_utils import gpu_rnnt + + +def rnnt_loss_cpu( + acts: torch.Tensor, + labels: torch.Tensor, + input_lengths: torch.Tensor, + label_lengths: torch.Tensor, + costs: torch.Tensor, + grads: torch.Tensor, + blank_label: int, + fastemit_lambda: float, + num_threads: int, +): + """ + Wrapper method for accessing CPU RNNT loss. + + CPU implementation ported from [HawkAaron/warp-transducer](https://github.com/HawkAaron/warp-transducer). + + Args: + acts: Activation tensor of shape [B, T, U, V+1]. + labels: Ground truth labels of shape [B, U]. + input_lengths: Lengths of the acoustic sequence as a vector of ints [B]. + label_lengths: Lengths of the target sequence as a vector of ints [B]. + costs: Zero vector of length [B] in which costs will be set. + grads: Zero tensor of shape [B, T, U, V+1] where the gradient will be set. + blank_label: Index of the blank token in the vocabulary. + fastemit_lambda: Float scaling factor for FastEmit regularization. Refer to + FastEmit: Low-latency Streaming ASR with Sequence-level Emission Regularization. + num_threads: Number of threads for OpenMP. + """ + # aliases + log_probs = acts + flat_labels = labels + + minibatch_size = log_probs.shape[0] + maxT = log_probs.shape[1] + maxU = log_probs.shape[2] + alphabet_size = log_probs.shape[3] + + if num_threads < 0: + num_threads = multiprocessing.cpu_count() + + num_threads = max(1, num_threads) # have to use at least 1 thread + + gpu_size, status = rnnt_helper.get_workspace_size(maxT, maxU, minibatch_size, gpu=False) + if status != global_constants.RNNTStatus.RNNT_STATUS_SUCCESS: + raise RuntimeError("Invalid parameter passed when calculating working space memory") + + cpu_workspace = torch.zeros(gpu_size, device=log_probs.device, dtype=log_probs.dtype, requires_grad=False) + + ### VIEW TENSORS AS VECTORS FOR POINTER INDEXING ### + log_probs, acts_shape = rnnt_helper.flatten_tensor(log_probs) + flat_labels, labels_shape = rnnt_helper.flatten_tensor(flat_labels) + + wrapper = cpu_rnnt.CPURNNT( + minibatch=minibatch_size, + maxT=maxT, + maxU=maxU, + alphabet_size=alphabet_size, + workspace=cpu_workspace, + blank=blank_label, + fastemit_lambda=fastemit_lambda, + num_threads=num_threads, + batch_first=True, + ) + + if grads is None: + status = wrapper.score_forward( + log_probs=log_probs.data, + costs=costs, + flat_labels=flat_labels.data, + label_lengths=label_lengths.data, + input_lengths=input_lengths.data, + ) + + if status != global_constants.RNNTStatus.RNNT_STATUS_SUCCESS: + raise RuntimeError("Could not calculate forward scores") + + else: + ### FLATTEN GRAD TENSOR ### + grads, grads_shape = rnnt_helper.flatten_tensor(grads) + + status = wrapper.cost_and_grad( + log_probs=log_probs.data, + grads=grads.data, + costs=costs, + flat_labels=flat_labels.data, + label_lengths=label_lengths.data, + input_lengths=input_lengths.data, + ) + + if status != global_constants.RNNTStatus.RNNT_STATUS_SUCCESS: + raise RuntimeError("Could not calculate forward scores") + + del cpu_workspace, wrapper + return True + + +def rnnt_loss_gpu( + acts: torch.Tensor, + labels: torch.Tensor, + input_lengths: torch.Tensor, + label_lengths: torch.Tensor, + costs: torch.Tensor, + grads: torch.Tensor, + blank_label: int, + fastemit_lambda: float, + num_threads: int, +): + """ + Wrapper method for accessing GPU RNNT loss. + + CUDA implementation ported from [HawkAaron/warp-transducer](https://github.com/HawkAaron/warp-transducer). + + Args: + acts: Activation tensor of shape [B, T, U, V+1]. + labels: Ground truth labels of shape [B, U]. + input_lengths: Lengths of the acoustic sequence as a vector of ints [B]. + label_lengths: Lengths of the target sequence as a vector of ints [B]. + costs: Zero vector of length [B] in which costs will be set. + grads: Zero tensor of shape [B, T, U, V+1] where the gradient will be set. + blank_label: Index of the blank token in the vocabulary. + fastemit_lambda: Float scaling factor for FastEmit regularization. Refer to + FastEmit: Low-latency Streaming ASR with Sequence-level Emission Regularization. + num_threads: Number of threads for OpenMP. + """ + minibatch_size = acts.shape[0] + maxT = acts.shape[1] + maxU = acts.shape[2] + alphabet_size = acts.shape[3] + + if hasattr(cuda, 'external_stream'): + stream = cuda.external_stream(torch.cuda.current_stream(acts.device).cuda_stream) + else: + stream = cuda.default_stream() + + if num_threads < 0: + num_threads = multiprocessing.cpu_count() + + num_threads = max(1, num_threads) # have to use at least 1 thread + + gpu_size, status = rnnt_helper.get_workspace_size(maxT, maxU, minibatch_size, gpu=True) + if status != global_constants.RNNTStatus.RNNT_STATUS_SUCCESS: + raise RuntimeError("Invalid parameter passed when calculating working space memory") + + # Select GPU index + cuda.select_device(acts.device.index) + gpu_workspace = torch.zeros(gpu_size, device=acts.device, dtype=acts.dtype, requires_grad=False) + + ### VIEW TENSORS AS VECTORS FOR POINTER INDEXING ### + acts, acts_shape = rnnt_helper.flatten_tensor(acts) + + ### REPRESENT THE CUDA ARRAY INTERFACE OF COSTS VECTOR ### + costs_repr = cuda.as_cuda_array(costs, sync=False) # NO COPY OF DATA, JUST CHANGE REPRESENTATION + + wrapper = gpu_rnnt.GPURNNT( + minibatch=minibatch_size, + maxT=maxT, + maxU=maxU, + alphabet_size=alphabet_size, + workspace=gpu_workspace, + blank=blank_label, + fastemit_lambda=fastemit_lambda, + num_threads=num_threads, + stream=stream, + ) + + if grads is None: + status = wrapper.score_forward( + acts=acts.data, + costs=costs_repr, + pad_labels=labels.data, + label_lengths=label_lengths.data, + input_lengths=input_lengths.data, + ) + + if status != global_constants.RNNTStatus.RNNT_STATUS_SUCCESS: + raise RuntimeError("Could not calculate forward scores") + + else: + ### FLATTEN GRAD TENSOR ### + grads, grads_shape = rnnt_helper.flatten_tensor(grads) + + status = wrapper.cost_and_grad( + acts=acts.data, + grads=grads.data, + costs=costs_repr, + pad_labels=labels.data, + label_lengths=label_lengths.data, + input_lengths=input_lengths.data, + ) + + if status != global_constants.RNNTStatus.RNNT_STATUS_SUCCESS: + raise RuntimeError("Could not calculate forward scores") + + del gpu_workspace, wrapper + return True diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_numpy.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_numpy.py new file mode 100644 index 0000000000000000000000000000000000000000..8a47b1a4041d7e6d082433e91d3935c95f8c494b --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_numpy.py @@ -0,0 +1,340 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Copyright 2018-2019, Mingkun Huang +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import numpy as np +import torch +from torch.autograd import Function, Variable +from torch.nn import Module + + +def check_type(var, t, name): + if var.dtype is not t: + raise TypeError("{} must be {}".format(name, t)) + + +def check_contiguous(var, name): + if not var.is_contiguous(): + raise ValueError("{} must be contiguous".format(name)) + + +def check_dim(var, dim, name): + if len(var.shape) != dim: + raise ValueError("{} must be {}D".format(name, dim)) + + +def certify_inputs(log_probs, labels, lengths, label_lengths): + # check_type(log_probs, torch.float32, "log_probs") + check_type(labels, torch.int32, "labels") + check_type(label_lengths, torch.int32, "label_lengths") + check_type(lengths, torch.int32, "lengths") + check_contiguous(log_probs, "log_probs") + check_contiguous(labels, "labels") + check_contiguous(label_lengths, "label_lengths") + check_contiguous(lengths, "lengths") + + if lengths.shape[0] != log_probs.shape[0]: + raise ValueError( + f"Must have a length per example. " + f"Given lengths dim: {lengths.shape[0]}, " + f"Log probs dim : {log_probs.shape[0]}" + ) + if label_lengths.shape[0] != log_probs.shape[0]: + raise ValueError( + "Must have a label length per example. " + f"Given label lengths dim : {label_lengths.shape[0]}, " + f"Log probs dim : {log_probs.shape[0]}" + ) + + check_dim(log_probs, 4, "log_probs") + check_dim(labels, 2, "labels") + check_dim(lengths, 1, "lenghts") + check_dim(label_lengths, 1, "label_lenghts") + max_T = torch.max(lengths) + max_U = torch.max(label_lengths) + T, U = log_probs.shape[1:3] + if T != max_T: + raise ValueError(f"Input length mismatch! Given T: {T}, Expected max T from input lengths: {max_T}") + if U != max_U + 1: + raise ValueError(f"Output length mismatch! Given U: {U}, Expected max U from target lengths: {max_U} + 1") + + +def _assert_no_grad(tensor): + assert not tensor.requires_grad, ( + "gradients only computed for log_probs - please " "mark other tensors as not requiring gradients" + ) + + +def forward_pass(log_probs, labels, blank): + """ + Computes probability of the forward variable alpha. + + Args: + log_probs: Tensor of shape [T, U, V+1] + labels: Labels of shape [B, U] + blank: Index of the blank token. + + Returns: + A tuple of the forward variable probabilities - alpha of shape [T, U] + and the log likelihood of this forward step. + """ + T, U, _ = log_probs.shape + alphas = np.zeros((T, U), dtype='f') + + for t in range(1, T): + alphas[t, 0] = alphas[t - 1, 0] + log_probs[t - 1, 0, blank] + + for u in range(1, U): + alphas[0, u] = alphas[0, u - 1] + log_probs[0, u - 1, labels[u - 1]] + for t in range(1, T): + for u in range(1, U): + no_emit = alphas[t - 1, u] + log_probs[t - 1, u, blank] + emit = alphas[t, u - 1] + log_probs[t, u - 1, labels[u - 1]] + alphas[t, u] = np.logaddexp(emit, no_emit) + + loglike = alphas[T - 1, U - 1] + log_probs[T - 1, U - 1, blank] + return alphas, loglike + + +def backward_pass(log_probs, labels, blank): + """ + Computes probability of the backward variable beta. + + Args: + log_probs: Tensor of shape [T, U, V+1] + labels: Labels of shape [B, U] + blank: Index of the blank token. + + Returns: + A tuple of the backward variable probabilities - beta of shape [T, U] + and the log likelihood of this backward step. + """ + T, U, _ = log_probs.shape + betas = np.zeros((T, U), dtype='f') + betas[T - 1, U - 1] = log_probs[T - 1, U - 1, blank] + + for t in reversed(range(T - 1)): + betas[t, U - 1] = betas[t + 1, U - 1] + log_probs[t, U - 1, blank] + + for u in reversed(range(U - 1)): + betas[T - 1, u] = betas[T - 1, u + 1] + log_probs[T - 1, u, labels[u]] + + for t in reversed(range(T - 1)): + for u in reversed(range(U - 1)): + no_emit = betas[t + 1, u] + log_probs[t, u, blank] + emit = betas[t, u + 1] + log_probs[t, u, labels[u]] + betas[t, u] = np.logaddexp(emit, no_emit) + + return betas, betas[0, 0] + + +def compute_gradient(log_probs, alphas, betas, labels, blank, fastemit_lambda): + """ + Computes the gradients of the log_probs with respect to the log probability of this step occuring. + + Args: + Args: + log_probs: Tensor of shape [T, U, V+1] + alphas: Tensor of shape [T, U] which represents the forward variable. + betas: Tensor of shape [T, U] which represents the backward variable. + labels: Labels of shape [B, U] + blank: Index of the blank token. + + Returns: + Gradients of shape [T, U, V+1] with respect to the forward log probability + """ + T, U, _ = log_probs.shape + grads = np.full(log_probs.shape, -float("inf")) + log_like = betas[0, 0] # == alphas[T - 1, U - 1] + betas[T - 1, U - 1] + + # // grad to last blank transition + grads[T - 1, U - 1, blank] = alphas[T - 1, U - 1] + grads[: T - 1, :, blank] = alphas[: T - 1, :] + betas[1:, :] + + # // grad to label transition + for u, l in enumerate(labels): + grads[:, u, l] = alphas[:, u] + betas[:, u + 1] + + grads = -np.exp(grads + log_probs - log_like) + + if fastemit_lambda > 0.0: + for u, l in enumerate(labels): + grads[:, u, l] = (1.0 + fastemit_lambda) * grads[:, u, l] + + return grads + + +def fastemit_regularization(log_probs, labels, alphas, betas, blank, fastemit_lambda): + """ + Describes the computation of FastEmit regularization from the paper - + [FastEmit: Low-latency Streaming ASR with Sequence-level Emission Regularization](https://arxiv.org/abs/2010.11148) + + Args: + log_probs: Tensor of shape [T, U, V+1] + labels: Unused. Labels of shape [B, U] + alphas: Tensor of shape [T, U] which represents the forward variable. + betas: Unused. Tensor of shape [T, U] which represents the backward variable. + blank: Index of the blank token. + fastemit_lambda: Float scaling factor for FastEmit regularization. + + Returns: + The regularized negative log likelihood - lambda * P˜(At, u|x) + """ + # General calculation of the fastemit regularization alignments + T, U, _ = log_probs.shape + # alignment = np.zeros((T, U), dtype='float32') + # + # for t in range(0, T): + # alignment[t, U - 1] = alphas[t, U - 1] + betas[t, U - 1] + # + # for t in range(0, T): + # for u in range(0, U - 1): + # emit = alphas[t, u] + log_probs[t, u, labels[u]] + betas[t, u + 1] + # alignment[t, u] = emit + # reg = fastemit_lambda * (alignment[T - 1, U - 1]) + + # The above is equivalent to below, without need of computing above + # reg = fastemit_lambda * (alphas[T - 1, U - 1] + betas[T - 1, U - 1]) + + # The above is also equivalent to below, without need of computing the betas alignment matrix + reg = fastemit_lambda * (alphas[T - 1, U - 1] + log_probs[T - 1, U - 1, blank]) + return -reg + + +def transduce(log_probs, labels, blank=0, fastemit_lambda=0.0): + """ + Args: + log_probs: 3D array with shape + [input len, output len + 1, vocab size] + labels: 1D array with shape [output time steps] + blank: Index of the blank token. + fastemit_lambda: Float scaling factor for FastEmit regularization. + + Returns: + float: The negative log-likelihood + 3D array: Gradients with respect to the + unnormalized input actications + 2d arrays: Alphas matrix (TxU) + 2d array: Betas matrix (TxU) + """ + alphas, ll_forward = forward_pass(log_probs, labels, blank) + betas, ll_backward = backward_pass(log_probs, labels, blank) + grads = compute_gradient(log_probs, alphas, betas, labels, blank, fastemit_lambda) + return -ll_forward, grads, alphas, betas + + +def transduce_batch(log_probs, labels, flen, glen, blank=0, fastemit_lambda=0.0): + """ + Compute the transducer loss of the batch. + + Args: + log_probs: [B, T, U, V+1]. Activation matrix normalized with log-softmax. + labels: [B, U+1] - ground truth labels with padded as blank token in the beginning. + flen: Length vector of the acoustic sequence. + glen: Length vector of the target sequence. + blank: Id of the blank token. + fastemit_lambda: Float scaling factor for FastEmit regularization. + + Returns: + Batch of transducer forward log probabilities (loss) and the gradients of the activation matrix. + """ + grads = np.zeros_like(log_probs) + costs = [] + for b in range(log_probs.shape[0]): + t = int(flen[b]) + u = int(glen[b]) + 1 + + ll, g, alphas, betas = transduce(log_probs[b, :t, :u, :], labels[b, : u - 1], blank, fastemit_lambda) + grads[b, :t, :u, :] = g + + reg = fastemit_regularization( + log_probs[b, :t, :u, :], labels[b, : u - 1], alphas, betas, blank, fastemit_lambda + ) + ll += reg + costs.append(ll) + return costs, grads + + +class _RNNT(Function): + @staticmethod + def forward(ctx, acts, labels, act_lens, label_lens, blank, fastemit_lambda): + costs, grads = transduce_batch( + acts.detach().cpu().numpy(), + labels.cpu().numpy(), + act_lens.cpu().numpy(), + label_lens.cpu().numpy(), + blank, + fastemit_lambda, + ) + + costs = torch.FloatTensor([sum(costs)]) + grads = torch.Tensor(grads).to(acts) + + ctx.grads = grads + return costs + + @staticmethod + def backward(ctx, grad_output): + return ctx.grads, None, None, None, None, None + + +class RNNTLoss(Module): + """ + Parameters: + `blank_label` (int): default 0 - label index of blank token + fastemit_lambda: Float scaling factor for FastEmit regularization. + """ + + def __init__(self, blank: int = 0, fastemit_lambda: float = 0.0): + super(RNNTLoss, self).__init__() + self.blank = blank + self.fastemit_lambda = fastemit_lambda + self.rnnt = _RNNT.apply + + def forward(self, acts, labels, act_lens, label_lens): + assert len(labels.size()) == 2 + _assert_no_grad(labels) + _assert_no_grad(act_lens) + _assert_no_grad(label_lens) + certify_inputs(acts, labels, act_lens, label_lens) + + acts = torch.nn.functional.log_softmax(acts, -1) + return self.rnnt(acts, labels, act_lens, label_lens, self.blank, self.fastemit_lambda) + + +if __name__ == '__main__': + loss = RNNTLoss(fastemit_lambda=0.01) + + torch.manual_seed(0) + + acts = torch.randn(1, 2, 5, 3) + labels = torch.tensor([[0, 2, 1, 2]], dtype=torch.int32) + act_lens = torch.tensor([2], dtype=torch.int32) + label_lens = torch.tensor([len(labels[0])], dtype=torch.int32) + + loss_val = loss(acts, labels, act_lens, label_lens) diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_pytorch.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..10d9073e7c819f16cd8a52a2df9d30d309fcfaa8 --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_pytorch.py @@ -0,0 +1,189 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Copyright 2018-2019, Mingkun Huang +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch +from torch.autograd import Function +from torch.nn import Module + +from nemo.collections.asr.parts.numba.rnnt_loss import rnnt + +__all__ = ['rnnt_loss', 'RNNTLossNumba'] + + +class _RNNTNumba(Function): + @staticmethod + def forward(ctx, acts, labels, act_lens, label_lens, blank, reduction, fastemit_lambda): + """ + log_probs: Tensor of (batch x seqLength x labelLength x outputDim) containing output from network + labels: 2 dimensional Tensor containing all the targets of the batch with zero padded + act_lens: Tensor of size (batch) containing size of each output sequence from the network + label_lens: Tensor of (batch) containing label length of each example + fastemit_lambda: Float scaling factor for FastEmit regularization. Refer to + FastEmit: Low-latency Streaming ASR with Sequence-level Emission Regularization. + """ + is_cuda = acts.is_cuda + + certify_inputs(acts, labels, act_lens, label_lens) + + loss_func = rnnt.rnnt_loss_gpu if is_cuda else rnnt.rnnt_loss_cpu + grads = torch.zeros_like(acts) if acts.requires_grad else None + minibatch_size = acts.size(0) + costs = torch.zeros(minibatch_size, device=acts.device, dtype=acts.dtype) + + loss_func( + acts, + labels=labels, + input_lengths=act_lens, + label_lengths=label_lens, + costs=costs, + grads=grads, + blank_label=blank, + fastemit_lambda=fastemit_lambda, + num_threads=0, + ) + + if reduction in ['sum', 'mean']: + costs = costs.sum().unsqueeze_(-1) + if reduction == 'mean': + costs /= minibatch_size + + if grads is not None: + grads /= minibatch_size + + ctx.grads = grads + + return costs + + @staticmethod + def backward(ctx, grad_output): + if grad_output is not None and ctx.grads is not None: + grad_output = grad_output.view(-1, 1, 1, 1).to(ctx.grads) + return ctx.grads.mul_(grad_output), None, None, None, None, None, None + + +def rnnt_loss(acts, labels, act_lens, label_lens, blank=0, reduction='mean'): + """RNN Transducer Loss + Args: + acts: Tensor of (batch x seqLength x labelLength x outputDim) containing output from network + labels: 2 dimensional Tensor containing all the targets of the batch with zero padded + act_lens: Tensor of size (batch) containing size of each output sequence from the network + label_lens: Tensor of (batch) containing label length of each example + blank (int, optional): blank label. Default: 0. + reduction (string, optional): Specifies the reduction to apply to the output: + 'none' | 'mean' | 'sum'. 'none': no reduction will be applied, + 'mean': the output losses will be divided by the target lengths and + then the mean over the batch is taken. Default: 'mean' + """ + if not acts.is_cuda: + acts = torch.nn.functional.log_softmax(acts, -1) + + return _RNNTNumba.apply(acts, labels, act_lens, label_lens, blank, reduction) + + +class RNNTLossNumba(Module): + """ + Parameters: + blank (int, optional): blank label. Default: 0. + reduction (string, optional): Specifies the reduction to apply to the output: + 'none' | 'mean' | 'sum'. 'none': no reduction will be applied, + 'mean': the output losses will be divided by the target lengths and + then the mean over the batch is taken. Default: 'mean' + """ + + def __init__(self, blank=0, reduction='mean', fastemit_lambda: float = 0.0): + super(RNNTLossNumba, self).__init__() + self.blank = blank + self.fastemit_lambda = fastemit_lambda + self.reduction = reduction + self.loss = _RNNTNumba.apply + + def forward(self, acts, labels, act_lens, label_lens): + """ + log_probs: Tensor of (batch x seqLength x labelLength x outputDim) containing output from network + labels: 2 dimensional Tensor containing all the targets of the batch with zero padded + act_lens: Tensor of size (batch) containing size of each output sequence from the network + label_lens: Tensor of (batch) containing label length of each example + """ + if not acts.is_cuda: + # NOTE manually done log_softmax for CPU version, + # log_softmax is computed within GPU version. + acts = torch.nn.functional.log_softmax(acts, -1) + + return self.loss(acts, labels, act_lens, label_lens, self.blank, self.reduction, self.fastemit_lambda) + + +def check_type(var, t, name): + if var.dtype is not t: + raise TypeError("{} must be {}".format(name, t)) + + +def check_contiguous(var, name): + if not var.is_contiguous(): + raise ValueError("{} must be contiguous".format(name)) + + +def check_dim(var, dim, name): + if len(var.shape) != dim: + raise ValueError("{} must be {}D".format(name, dim)) + + +def certify_inputs(log_probs, labels, lengths, label_lengths): + # check_type(log_probs, torch.float32, "log_probs") + check_type(labels, torch.int32, "labels") + check_type(label_lengths, torch.int32, "label_lengths") + check_type(lengths, torch.int32, "lengths") + check_contiguous(log_probs, "log_probs") + check_contiguous(labels, "labels") + check_contiguous(label_lengths, "label_lengths") + check_contiguous(lengths, "lengths") + + if lengths.shape[0] != log_probs.shape[0]: + raise ValueError( + f"Must have a length per example. " + f"Given lengths dim: {lengths.shape[0]}, " + f"Log probs dim : {log_probs.shape[0]}" + ) + if label_lengths.shape[0] != log_probs.shape[0]: + raise ValueError( + "Must have a label length per example. " + f"Given label lengths dim : {label_lengths.shape[0]}, " + f"Log probs dim : {log_probs.shape[0]}" + ) + + check_dim(log_probs, 4, "log_probs") + check_dim(labels, 2, "labels") + check_dim(lengths, 1, "lenghts") + check_dim(label_lengths, 1, "label_lenghts") + max_T = torch.max(lengths) + max_U = torch.max(label_lengths) + T, U = log_probs.shape[1:3] + if T != max_T: + raise ValueError(f"Input length mismatch! Given T: {T}, Expected max T from input lengths: {max_T}") + if U != max_U + 1: + raise ValueError(f"Output length mismatch! Given U: {U}, Expected max U from target lengths: {max_U} + 1") diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/numba/spec_augment/__init__.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/numba/spec_augment/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..17a22fcf81889ff15d8cedd8c306bd0d383dcf58 --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/numba/spec_augment/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nemo.collections.asr.parts.numba.spec_augment.spec_aug_numba import ( + SpecAugmentNumba, + spec_augment_launch_heuristics, +) diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/numba/spec_augment/spec_aug_numba.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/numba/spec_augment/spec_aug_numba.py new file mode 100644 index 0000000000000000000000000000000000000000..01c5ea942cf12bac367595aee365ab7da2b35835 --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/numba/spec_augment/spec_aug_numba.py @@ -0,0 +1,305 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn +from numba import cuda + +from nemo.core.classes import Typing, typecheck +from nemo.core.neural_types import LengthsType, NeuralType, SpectrogramType +from nemo.utils import logging + +MAX_THREAD_BUFFER = 512 + + +@cuda.jit() +def spec_augment_kernel( + x: torch.Tensor, + x_len: torch.Tensor, + freq_starts: torch.Tensor, + freq_widths: torch.Tensor, + time_starts: torch.Tensor, + time_widths: torch.Tensor, + mask_value: float, +): + """ + Numba CUDA kernel to perform SpecAugment in-place on the GPU. + Parallelize over freq and time axis, parallel threads over batch. + Sequential over masks (adaptive in time). + + Args: + x: Pytorch tensor of shape [B, F, T] with the acoustic features. + x_len: Pytorch tensor of shape [B] with the lengths of the padded sequence. + freq_starts: Pytorch tensor of shape [B, M_f] with the start indices of freq masks. + freq_widths: Pytorch tensor of shape [B, M_f] with the width of freq masks. + time_starts: Pytorch tensor of shape [B, M_t] with the start indices of time masks. + time_widths: Pytorch tensor of shape [B, M_t] with the width of time masks. + mask_value: Float value that will be used as mask value. + """ + f = cuda.blockIdx.x # indexes the Freq dim + t = cuda.blockIdx.y # indexes the Time dim + tid = cuda.threadIdx.x # index of the current mask + threads_per_block = cuda.blockDim.x + + # Compute the number of masks over freq axis + len_f = freq_starts.shape[1] + # For all samples in the batch, apply the freq mask + for bidx in range(0, x.shape[0], threads_per_block): + # Resolve the index of the batch (case where more masks than MAX_THREAD_BUFFER) + bm_idx = bidx + tid + + # Access mask only if valid sample id in batch + if bm_idx < x.shape[0]: + # For `len_f` number of freq masks that must be applied + for fidx in range(0, len_f): + # Access the start index and width of this freq mask + f_start = freq_starts[bm_idx, fidx] + f_width = freq_widths[bm_idx, fidx] + + # If block idx `f` >= start and < (start + width) of this freq mask + if f >= f_start and f < (f_start + f_width): + x[bm_idx, f, t] = mask_value + + # Compute the number of masks over time axis + len_t = time_starts.shape[1] + # For all samples in the batch, apply the time mask + for b_idx in range(0, x.shape[0], threads_per_block): + # Resolve the index of the batch (case where more masks than MAX_THREAD_BUFFER) + bm_idx = b_idx + tid + + # Access mask only if valid sample id in batch + if bm_idx < x.shape[0]: + # For `len_t` number of freq masks that must be applied + for tidx in range(0, len_t): + # Access the start index and width of this time mask + t_start = time_starts[bm_idx, tidx] + t_width = time_widths[bm_idx, tidx] + + # If block idx `t` >= start and < (start + width) of this time mask + if t >= t_start and t < (t_start + t_width): + # Current block idx `t` < current seq length x_len[b] + # This ensure that we mask only upto the length of that sample + # Everything after that index is padded value so unnecessary to mask + if t < x_len[bm_idx]: + x[bm_idx, f, t] = mask_value + + +def spec_augment_launch_heuristics(x: torch.Tensor, length: torch.Tensor): + """ + Heuristics to determins whether pytorch implementation or numba implementation is selected. + Assumes numba cuda is supported. + + Args: + x: Torch tensor of shape [B, F, T] + length: Optional, Torch of tensor of shape [B] - containing lengths of the tensor. + + Returns: + True if numba kernel should be selected, else False + """ + if not x.is_cuda: + return False + + if length is None: + return False + + if x.shape[0] < 8: + return False + + return True + + +def launch_spec_augment_kernel( + x: torch.Tensor, + x_len: torch.Tensor, + freq_starts: torch.Tensor, + freq_lengths: torch.Tensor, + time_starts: torch.Tensor, + time_lengths: torch.Tensor, + freq_masks: int, + time_masks: int, + mask_value: float, +): + """ + Helper method to launch the SpecAugment kernel + + Args: + x: Pytorch tensor of shape [B, F, T] with the acoustic features. + x_len: Pytorch tensor of shape [B] with the lengths of the padded sequence. + freq_starts: Pytorch tensor of shape [B, M_f] with the start indices of freq masks. + freq_widths: Pytorch tensor of shape [B, M_f] with the width of freq masks. + time_starts: Pytorch tensor of shape [B, M_t] with the start indices of time masks. + time_widths: Pytorch tensor of shape [B, M_t] with the width of time masks. + freq_masks: Int value that determines the number of time masks. + time_masks: Int value that determines the number of freq masks. + mask_value: Float value that will be used as mask value. + + Returns: + The spec augmented tensor 'x' + """ + # Setup CUDA stream + sh = x.shape + stream = cuda.external_stream(torch.cuda.current_stream(x.device).cuda_stream) + + if time_masks > 0 or freq_masks > 0: + # Parallelize over freq and time axis, parallel threads over batch + # Sequential over masks (adaptive in time). + blocks_per_grid = [sh[1], sh[2]] + # threads_per_block = min(MAX_THREAD_BUFFER, max(freq_masks, time_masks)) + threads_per_block = min(MAX_THREAD_BUFFER, x.shape[0]) + + # Numba does not support fp16, force cast to fp32 temporarily at the expense of memory + original_dtype = x.dtype + cast_x = False + if x.dtype == torch.float16: + x = x.float() + cast_x = True + + # Launch CUDA kernel + spec_augment_kernel[blocks_per_grid, threads_per_block, stream, 0]( + x, x_len, freq_starts, freq_lengths, time_starts, time_lengths, mask_value + ) + torch.cuda.synchronize() + + # Recast back to original dtype if earlier cast was performed + if cast_x: + x = x.to(dtype=original_dtype) + + return x + + +class SpecAugmentNumba(nn.Module, Typing): + """ + Zeroes out(cuts) random continuous horisontal or + vertical segments of the spectrogram as described in + SpecAugment (https://arxiv.org/abs/1904.08779). + + Utilizes a Numba CUDA kernel to perform inplace edit of the input without loops. + Parallelize over freq and time axis, parallel threads over batch. + Sequential over masks (adaptive in time). + + Args: + freq_masks - how many frequency segments should be cut + time_masks - how many time segments should be cut + freq_width - maximum number of frequencies to be cut in one segment + time_width - maximum number of time steps to be cut in one segment. + Can be a positive integer or a float value in the range [0, 1]. + If positive integer value, defines maximum number of time steps + to be cut in one segment. + If a float value, defines maximum percentage of timesteps that + are cut adaptively. + rng: Ignored. + """ + + @property + def input_types(self): + """Returns definitions of module input types + """ + return { + "input_spec": NeuralType(('B', 'D', 'T'), SpectrogramType()), + "length": NeuralType(tuple('B'), LengthsType()), + } + + @property + def output_types(self): + """Returns definitions of module output types + """ + return {"augmented_spec": NeuralType(('B', 'D', 'T'), SpectrogramType())} + + def __init__( + self, freq_masks=0, time_masks=0, freq_width=10, time_width=0.1, rng=None, mask_value=0.0, + ): + super().__init__() + # Message to mention that numba specaugment kernel will be available + # if input device is CUDA and lengths are provided + logging.debug("Numba SpecAugment kernel is available") + + self.freq_masks = freq_masks + self.time_masks = time_masks + + self.freq_width = freq_width + self.time_width = time_width + + self.mask_value = mask_value + + # Unused + self.rng = rng + if self.rng is not None: + logging.warning("`rng` was supplied to SpecAugmentNumba, but it is not used.") + + if isinstance(time_width, int): + self.adaptive_temporal_width = False + else: + if time_width > 1.0 or time_width < 0.0: + raise ValueError('If `time_width` is a float value, must be in range [0, 1]') + + self.adaptive_temporal_width = True + + @typecheck() + @torch.no_grad() + def forward(self, input_spec, length): + sh = input_spec.shape + bs = sh[0] + + # Construct the freq and time masks as well as start positions + if self.freq_masks > 0: + freq_starts = torch.randint( + 0, sh[1] - self.freq_width + 1, size=[bs, self.freq_masks], device=input_spec.device + ) + freq_lengths = torch.randint(0, self.freq_width + 1, size=[bs, self.freq_masks], device=input_spec.device) + else: + freq_starts = torch.zeros([bs, 1], dtype=torch.int64, device=input_spec.device) + freq_lengths = torch.zeros([bs, 1], dtype=torch.int64, device=input_spec.device) + + if self.time_masks > 0: + if self.adaptive_temporal_width: + time_width = (length * self.time_width).int().clamp(min=1) + else: + time_width = ( + torch.tensor(self.time_width, dtype=torch.int32, device=input_spec.device) + .unsqueeze(0) + .repeat(sh[0]) + ) + + time_starts = [] + time_lengths = [] + for idx in range(sh[0]): + time_starts.append( + torch.randint( + 0, max(1, length[idx] - time_width[idx]), size=[1, self.time_masks], device=input_spec.device + ) + ) + time_lengths.append( + torch.randint(0, time_width[idx] + 1, size=[1, self.time_masks], device=input_spec.device) + ) + + time_starts = torch.cat(time_starts, 0) + time_lengths = torch.cat(time_lengths, 0) + + else: + time_starts = torch.zeros([bs, 1], dtype=torch.int64, device=input_spec.device) + time_lengths = torch.zeros([bs, 1], dtype=torch.int64, device=input_spec.device) + + x = launch_spec_augment_kernel( + input_spec, + length, + freq_starts=freq_starts, + freq_lengths=freq_lengths, + time_starts=time_starts, + time_lengths=time_lengths, + freq_masks=self.freq_masks, + time_masks=self.time_masks, + mask_value=self.mask_value, + ) + + return x diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/numba_utils.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/numba_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8a5048aaf7486255f4d75c6e515756bd0b5c5b1d --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/numba_utils.py @@ -0,0 +1,88 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import numpy as np +from numba import jit + + +def phase_vocoder(D: np.ndarray, rate: float, phi_advance: np.ndarray, scale_buffer: np.ndarray): + """ + Optimized implementation of phase vocoder from Librosa. + Reference implementation: + - https://librosa.github.io/librosa/generated/librosa.core.phase_vocoder.html + Args: + D: Complex spectograms of shape [d, t, complex=2]. + rate: Speed rate, must be float greater than 0. + phi_advance: Precomputed phase advance buffer array of length [n_fft + 1] + scale_buffer: Precomputed numpy buffer array of length [n_fft + 1] + Returns: + Complex64 ndarray of shape [d, t / rate, complex=2] + """ + time_steps = np.arange(0, D.shape[1], rate, dtype=np.float) + + # Create an empty output array + d_stretch = np.zeros((D.shape[0], len(time_steps)), D.dtype, order='F') + + # Phase accumulator; initialize to the first sample + phase_acc = np.angle(D[:, 0]) + + # Pad 0 columns to simplify boundary logic + D = np.pad(D, [(0, 0), (0, 2)], mode='constant') + + d_stretch = _phase_vocoder_kernel(D, time_steps, phi_advance, d_stretch, phase_acc, scale_buffer) + + return d_stretch + + +@jit(nopython=True, nogil=True) +def _phase_vocoder_kernel(D, time_steps, phi_advance, d_stretch, phase_acc, scale_buffer): + """ + Numba optimized kernel to compute the phase vocoder step. + Args: + D: Complex spectograms of shape [d, t, complex=2]. + rate: Speed rate, must be float greater than 0. + time_steps: Numpy ndarray of linearly spaced time steps, shape = [t] + phi_advance: Precomputed phase advance buffer array of length [n_fft + 1] + d_stretch: Output complex matrix of shape [d, t / rate, complex=2] + phase_acc: Phase accumulator initialized to first sample of shape [d, complex=2] + scale_buffer: Precomputed numpy buffer array of length [n_fft + 1] + Returns: + Complex64 ndarray of shape [d, t / rate, complex=2] + """ + two_pi = 2.0 * np.pi + + for (t, step) in enumerate(time_steps): + columns = D[:, int(step) : int(step + 2)] + columns_0 = columns[:, 0] + columns_1 = columns[:, 1] + + # Weighting for linear magnitude interpolation + alpha = np.mod(step, 1.0) + mag = (1.0 - alpha) * np.abs(columns_0) + alpha * np.abs(columns_1) + + # Store to output array + d_stretch[:, t] = mag * np.exp(1.0j * phase_acc) + + # Compute phase advance + dphase = np.angle(columns_1) - np.angle(columns_0) - phi_advance + + # Wrap to -pi:pi range + scale = dphase / two_pi + np.round(scale, 0, scale_buffer) + + dphase = dphase - two_pi * scale_buffer + + # Accumulate phase + phase_acc += phi_advance + dphase + + return d_stretch diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/parsers.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/parsers.py new file mode 100644 index 0000000000000000000000000000000000000000..e4640bcdca516da14340136f870950dd0bf857d1 --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/parsers.py @@ -0,0 +1,264 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import string +from typing import List, Optional + +import frozendict + +from nemo.collections.asr.parts import cleaners + + +class PhoneParser: + """Functor for parsing raw strings into list of int tokens. + + Examples: + >>> parser = PhoneParser(['AY1', 'HH', 'AE1']) + >>> parser(['AY1', 'HH', 'AE1']) + [0, 1, 2] + """ + + def __init__( + self, + labels: List[str], + *, + unk_id: int = -1, + blank_id: int = -1, + do_normalize: bool = False, + do_lowercase: bool = False, + add_end_space: bool = False + ): + """Creates simple mapping phone parser. + + Args: + labels: List of labels to allocate indexes for. Essentially, + this is a id to str mapping. + unk_id: Index to choose for OOV words (default: -1). + blank_id: Index to filter out from final list of tokens + (default: -1). + do_normalize: True if apply normalization step before tokenizing + (default: False). + do_lowercase: True if apply lowercasing at normalizing step + (default: False) + """ + + self._labels = labels + self._unk_id = unk_id + self._blank_id = blank_id + self._do_normalize = do_normalize + self._do_lowercase = do_lowercase + assert not self._do_normalize # because the input phone list is from the exactly same set + assert not self._do_lowercase # because the input phone list is from the exactly same set + + self._labels_map = {label: index for index, label in enumerate(labels)} + + def __call__(self, text: str) -> Optional[List[int]]: + text_tokens = self._tokenize(text) + return text_tokens + + def _tokenize(self, text: List) -> List[int]: + """This part is the important one to be modified""" + # here the text is actually a list of phone + tokens = [] + # Split by word for find special labels. + for phone in text: + tokens.append(self._labels_map.get(phone, self._unk_id)) + + # If unk_id == blank_id, OOV tokens are removed. + tokens = [token for token in tokens if token != self._blank_id] + + return tokens + + def decode(self, str_input): + r_map = {} + for k, v in self._labels_map.items(): + r_map[v] = k + r_map[len(self._labels_map)] = "" + r_map[len(self._labels_map) + 1] = "" + r_map[len(self._labels_map) + 2] = "

" + + out = [] + for i in str_input: + # Skip OOV + if i not in r_map: + continue + out.append(r_map[i.item()]) + + return " ".join(out) + + +class CharParser: + """Functor for parsing raw strings into list of int tokens. + + Examples: + >>> parser = CharParser(['a', 'b', 'c']) + >>> parser('abc') + [0, 1, 2] + """ + + def __init__( + self, + labels: List[str], + *, + unk_id: int = -1, + blank_id: int = -1, + do_normalize: bool = True, + do_lowercase: bool = True, + add_end_space: bool = False + ): + """Creates simple mapping char parser. + + Args: + labels: List of labels to allocate indexes for. Essentially, + this is a id to str mapping. + unk_id: Index to choose for OOV words (default: -1). + blank_id: Index to filter out from final list of tokens + (default: -1). + do_normalize: True if apply normalization step before tokenizing + (default: True). + do_lowercase: True if apply lowercasing at normalizing step + (default: True). + """ + + self._labels = labels + self._unk_id = unk_id + self._blank_id = blank_id + self._do_normalize = do_normalize + self._do_lowercase = do_lowercase + + self._labels_map = {label: index for index, label in enumerate(labels)} + self._special_labels = set([label for label in labels if len(label) > 1]) + + print('INFO: CharParser add_end_space: {}'.format(add_end_space)) + self.add_end_space = add_end_space + + def __call__(self, text: str) -> Optional[List[int]]: + if self._do_normalize: + text = self._normalize(text) + if text is None: + return None + + text_tokens = self._tokenize(text) + + return text_tokens + + def _normalize(self, text: str) -> Optional[str]: + text = text.strip() + + if self._do_lowercase: + text = text.lower() + + return text + + def _tokenize(self, text: str) -> List[int]: + tokens = [] + # Split by word for find special labels. + for word_id, word in enumerate(text.split(' ')): + if word_id != 0 and not self.add_end_space: # Not first word - so we insert space before. + tokens.append(self._labels_map.get(' ', self._unk_id)) + + if word in self._special_labels: + tokens.append(self._labels_map[word]) + continue + + for char in word: + tokens.append(self._labels_map.get(char, self._unk_id)) + + if self.add_end_space: + tokens.append(self._labels_map.get(' ', self._unk_id)) + + # If unk_id == blank_id, OOV tokens are removed. + tokens = [token for token in tokens if token != self._blank_id] + + return tokens + + +class ENCharParser(CharParser): + """Incorporates english-specific parsing logic.""" + + PUNCTUATION_TO_REPLACE = frozendict.frozendict({'+': 'plus', '&': 'and', '%': 'percent'}) + + def __init__(self, *args, **kwargs): + """Creates english-specific mapping char parser. + + This class overrides normalizing implementation. + + Args: + *args: Positional args to pass to `CharParser` constructor. + **kwargs: Key-value args to pass to `CharParser` constructor. + """ + + super().__init__(*args, **kwargs) + + self._table = self.__make_trans_table() + + def __make_trans_table(self): + punctuation = string.punctuation + + for char in self.PUNCTUATION_TO_REPLACE: + punctuation = punctuation.replace(char, '') + + for label in self._labels: + punctuation = punctuation.replace(label, '') + + table = str.maketrans(punctuation, ' ' * len(punctuation)) + + return table + + def _normalize(self, text: str) -> Optional[str]: + # noinspection PyBroadException + try: + text = cleaners.clean_text( + string=text, table=self._table, punctuation_to_replace=self.PUNCTUATION_TO_REPLACE, + ) + except Exception: + return None + + return text + + +NAME_TO_PARSER = frozendict.frozendict({'base': CharParser, 'en': ENCharParser}) + + +def make_parser(labels: Optional[List[str]] = None, name: str = 'base', **kwargs, ) -> CharParser: + """Creates parser from labels, set of arguments and concise parser name. + + Args: + labels: List of labels to allocate indexes for. If set to + None then labels would be ascii table list. Essentially, this is a + id to str mapping (default: None). + name: Concise name of parser to create (default: 'base'). + (default: -1). + **kwargs: Other set of kwargs to pass to parser constructor. + + Returns: + Instance of `CharParser`. + + Raises: + ValueError: For invalid parser name. + + Examples: + >>> type(make_parser(['a', 'b', 'c'], 'en')) + ENCharParser + """ + + if name not in NAME_TO_PARSER: + raise ValueError('Invalid parser name.') + + if labels is None: + labels = list(string.printable) + + parser_type = NAME_TO_PARSER[name] + parser = parser_type(labels=labels, **kwargs) + + return parser diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/perturb.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/perturb.py new file mode 100644 index 0000000000000000000000000000000000000000..310b4a859e99fd85999554eee0cb2bf426090675 --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/perturb.py @@ -0,0 +1,1028 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Copyright (c) 2018 Ryan Leary +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# This file contains code artifacts adapted from https://github.com/ryanleary/patter +import copy +import io +import os +import random +import subprocess +from tempfile import NamedTemporaryFile +from typing import List, Optional, Union + +import librosa +import numpy as np +import pandas +import soundfile as sf +import sox +import webdataset as wd +from omegaconf import DictConfig, OmegaConf +from scipy import signal +from torch.utils.data import IterableDataset + +from nemo.collections.asr.parts import collections, parsers +from nemo.collections.asr.parts.segment import AudioSegment +from nemo.utils import logging + +try: + from nemo.collections.asr.parts import numba_utils + + HAVE_NUMBA = True +except (ImportError, ModuleNotFoundError): + HAVE_NUMBA = False + + +def read_one_audiosegment(manifest, target_sr, rng, tarred_audio=False, audio_dataset=None): + + if tarred_audio: + if audio_dataset is None: + raise TypeError("Expected augmentation dataset but got None") + audio_file, file_id = next(audio_dataset) + manifest_idx = manifest.mapping[file_id] + manifest_entry = manifest[manifest_idx] + + offset = 0 if manifest_entry.offset is None else manifest_entry.offset + duration = 0 if manifest_entry.duration is None else manifest_entry.duration + + else: + audio_record = rng.sample(manifest.data, 1)[0] + audio_file = audio_record.audio_file + offset = 0 if audio_record.offset is None else audio_record.offset + duration = 0 if audio_record.duration is None else audio_record.duration + + return AudioSegment.from_file(audio_file, target_sr=target_sr, offset=offset, duration=duration) + + +class Perturbation(object): + def max_augmentation_length(self, length): + return length + + def perturb(self, data): + raise NotImplementedError + + +class SpeedPerturbation(Perturbation): + def __init__(self, sr, resample_type, min_speed_rate=0.9, max_speed_rate=1.1, num_rates=5, rng=None): + """ + Performs Speed Augmentation by re-sampling the data to a different sampling rate, + which does not preserve pitch. + + Note: This is a very slow operation for online augmentation. If space allows, + it is preferable to pre-compute and save the files to augment the dataset. + + Args: + sr: Original sampling rate. + resample_type: Type of resampling operation that will be performed. + For better speed using `resampy`'s fast resampling method, use `resample_type='kaiser_fast'`. + For high-quality resampling, set `resample_type='kaiser_best'`. + To use `scipy.signal.resample`, set `resample_type='fft'` or `resample_type='scipy'` + min_speed_rate: Minimum sampling rate modifier. + max_speed_rate: Maximum sampling rate modifier. + num_rates: Number of discrete rates to allow. Can be a positive or negative + integer. + If a positive integer greater than 0 is provided, the range of + speed rates will be discretized into `num_rates` values. + If a negative integer or 0 is provided, the full range of speed rates + will be sampled uniformly. + Note: If a positive integer is provided and the resultant discretized + range of rates contains the value '1.0', then those samples with rate=1.0, + will not be augmented at all and simply skipped. This is to unnecessary + augmentation and increase computation time. Effective augmentation chance + in such a case is = `prob * (num_rates - 1 / num_rates) * 100`% chance + where `prob` is the global probability of a sample being augmented. + rng: Random seed number. + """ + min_rate = min(min_speed_rate, max_speed_rate) + if min_rate < 0.0: + raise ValueError("Minimum sampling rate modifier must be > 0.") + + if resample_type not in ('kaiser_best', 'kaiser_fast', 'fft', 'scipy'): + raise ValueError("Supported `resample_type` values are ('kaiser_best', 'kaiser_fast', 'fft', 'scipy')") + + self._sr = sr + self._min_rate = min_speed_rate + self._max_rate = max_speed_rate + self._num_rates = num_rates + if num_rates > 0: + self._rates = np.linspace(self._min_rate, self._max_rate, self._num_rates, endpoint=True) + self._res_type = resample_type + self._rng = random.Random() if rng is None else rng + + def max_augmentation_length(self, length): + return length * self._max_rate + + def perturb(self, data): + # Select speed rate either from choice or random sample + if self._num_rates < 0: + speed_rate = self._rng.uniform(self._min_rate, self._max_rate) + else: + speed_rate = self._rng.choice(self._rates) + + # Skip perturbation in case of identity speed rate + if speed_rate == 1.0: + return + + new_sr = int(self._sr * speed_rate) + data._samples = librosa.core.resample(data._samples, self._sr, new_sr, res_type=self._res_type) + + +class TimeStretchPerturbation(Perturbation): + def __init__(self, min_speed_rate=0.9, max_speed_rate=1.1, num_rates=5, n_fft=512, rng=None): + """ + Time-stretch an audio series by a fixed rate while preserving pitch, based on [1, 2]. + + Note: + This is a simplified implementation, intended primarily for reference and pedagogical purposes. + It makes no attempt to handle transients, and is likely to produce audible artifacts. + + Reference + [1] [Ellis, D. P. W. ā€œA phase vocoder in Matlab.ā€ Columbia University, 2002.] + (http://www.ee.columbia.edu/~dpwe/resources/matlab/pvoc/) + [2] [librosa.effects.time_stretch] + (https://librosa.github.io/librosa/generated/librosa.effects.time_stretch.html) + + Args: + min_speed_rate: Minimum sampling rate modifier. + max_speed_rate: Maximum sampling rate modifier. + num_rates: Number of discrete rates to allow. Can be a positive or negative + integer. + If a positive integer greater than 0 is provided, the range of + speed rates will be discretized into `num_rates` values. + If a negative integer or 0 is provided, the full range of speed rates + will be sampled uniformly. + Note: If a positive integer is provided and the resultant discretized + range of rates contains the value '1.0', then those samples with rate=1.0, + will not be augmented at all and simply skipped. This is to avoid unnecessary + augmentation and increase computation time. Effective augmentation chance + in such a case is = `prob * (num_rates - 1 / num_rates) * 100`% chance + where `prob` is the global probability of a sample being augmented. + n_fft: Number of fft filters to be computed. + rng: Random seed number. + """ + min_rate = min(min_speed_rate, max_speed_rate) + if min_rate < 0.0: + raise ValueError("Minimum sampling rate modifier must be > 0.") + + self._min_rate = min_speed_rate + self._max_rate = max_speed_rate + self._num_rates = num_rates + if num_rates > 0: + self._rates = np.linspace(self._min_rate, self._max_rate, self._num_rates, endpoint=True) + self._rng = random.Random() if rng is None else rng + + # Pre-compute constants + self._n_fft = int(n_fft) + self._hop_length = int(n_fft // 2) + + # Pre-allocate buffers + self._phi_advance_fast = np.linspace(0, np.pi * self._hop_length, self._hop_length + 1) + self._scale_buffer_fast = np.empty(self._hop_length + 1, dtype=np.float32) + + self._phi_advance_slow = np.linspace(0, np.pi * self._n_fft, self._n_fft + 1) + self._scale_buffer_slow = np.empty(self._n_fft + 1, dtype=np.float32) + + def max_augmentation_length(self, length): + return length * self._max_rate + + def perturb(self, data): + # Select speed rate either from choice or random sample + if self._num_rates < 0: + speed_rate = self._rng.uniform(self._min_rate, self._max_rate) + else: + speed_rate = self._rng.choice(self._rates) + + # Skip perturbation in case of identity speed rate + if speed_rate == 1.0: + return + + # Increase `n_fft` based on task (speed up or slow down audio) + # This greatly reduces upper bound of maximum time taken + # to compute slowed down audio segments. + if speed_rate >= 1.0: # Speed up audio + fft_multiplier = 1 + phi_advance = self._phi_advance_fast + scale_buffer = self._scale_buffer_fast + + else: # Slow down audio + fft_multiplier = 2 + phi_advance = self._phi_advance_slow + scale_buffer = self._scale_buffer_slow + + n_fft = int(self._n_fft * fft_multiplier) + hop_length = int(self._hop_length * fft_multiplier) + + # Perform short-term Fourier transform (STFT) + stft = librosa.core.stft(data._samples, n_fft=n_fft, hop_length=hop_length) + + # Stretch by phase vocoding + if HAVE_NUMBA: + stft_stretch = numba_utils.phase_vocoder(stft, speed_rate, phi_advance, scale_buffer) + + else: + stft_stretch = librosa.core.phase_vocoder(stft, speed_rate, hop_length) + + # Predict the length of y_stretch + len_stretch = int(round(len(data._samples) / speed_rate)) + + # Invert the STFT + y_stretch = librosa.core.istft( + stft_stretch, dtype=data._samples.dtype, hop_length=hop_length, length=len_stretch + ) + + data._samples = y_stretch + + +class TempoPerturbation(Perturbation): + def __init__(self, factors, rng=None): + assert len(factors) > 0 + assert min(factors) > 0 + self.factors = factors + self._max_factor = max(self.factors) + self._rng = random.Random() if rng is None else rng + + def max_augmentation_length(self, length): + return length * self._max_factor + + def perturb(self, data): + speed_rate = self._rng.choice(self.factors) + if speed_rate == 1.0: + return + + tfm = sox.Transformer() + if abs(speed_rate - 1.0) <= 0.1: + tfm.stretch(speed_rate) + else: + tfm.tempo(speed_rate) + perturbed_data = tfm.build_array(input_array=data._samples, sample_rate_in=data._sample_rate) + + data._samples = perturbed_data + + +class GainPerturbation(Perturbation): + """ + Applies random gain to the audio. + + Args: + min_gain_dbfs (float): Min gain level in dB + max_gain_dbfs (float): Max gain level in dB + rng: Random number generator + """ + + def __init__(self, min_gain_dbfs=-10, max_gain_dbfs=10, rng=None): + self._min_gain_dbfs = min_gain_dbfs + self._max_gain_dbfs = max_gain_dbfs + self._rng = random.Random() if rng is None else rng + + def perturb(self, data): + gain = self._rng.uniform(self._min_gain_dbfs, self._max_gain_dbfs) + # logging.debug("gain: %d", gain) + data._samples = data._samples * (10.0 ** (gain / 20.0)) + + +class ImpulsePerturbation(Perturbation): + """ + Convolves audio with a Room Impulse Response. + + Args: + manifest_path (list): Manifest file for RIRs + audio_tar_filepaths (list): Tar files, if RIR audio files are tarred + shuffle_n (int): Shuffle parameter for shuffling buffered files from the tar files + shift_impulse (bool): Shift impulse response to adjust for delay at the beginning + """ + + def __init__(self, manifest_path=None, rng=None, audio_tar_filepaths=None, shuffle_n=128, shift_impulse=False): + self._manifest = collections.ASRAudioText(manifest_path, parser=parsers.make_parser([]), index_by_file_id=True) + self._audiodataset = None + self._tarred_audio = False + self._shift_impulse = shift_impulse + self._data_iterator = None + + if audio_tar_filepaths: + self._tarred_audio = True + self._audiodataset = AugmentationDataset(manifest_path, audio_tar_filepaths, shuffle_n) + self._data_iterator = iter(self._audiodataset) + + self._rng = random.Random() if rng is None else rng + + def perturb(self, data): + impulse = read_one_audiosegment( + self._manifest, + data.sample_rate, + self._rng, + tarred_audio=self._tarred_audio, + audio_dataset=self._data_iterator, + ) + if not self._shift_impulse: + impulse_norm = (impulse.samples - min(impulse.samples)) / (max(impulse.samples) - min(impulse.samples)) + data._samples = signal.fftconvolve(data._samples, impulse_norm, "same") + else: + # Find peak and shift peak to left + impulse_norm = (impulse.samples - min(impulse.samples)) / (max(impulse.samples) - min(impulse.samples)) + max_ind = np.argmax(np.abs(impulse_norm)) + + impulse_resp = impulse_norm[max_ind:] + delay_after = len(impulse_resp) + data._samples = signal.fftconvolve(data._samples, impulse_resp, "full")[:-delay_after] + + +class ShiftPerturbation(Perturbation): + """ + Perturbs audio by shifting the audio in time by a random amount between min_shift_ms and max_shift_ms. + The final length of the audio is kept unaltered by padding the audio with zeros. + + + Args: + min_shift_ms (float): Minimum time in milliseconds by which audio will be shifted + max_shift_ms (float): Maximum time in milliseconds by which audio will be shifted + rng: Random number generator + """ + + def __init__(self, min_shift_ms=-5.0, max_shift_ms=5.0, rng=None): + self._min_shift_ms = min_shift_ms + self._max_shift_ms = max_shift_ms + self._rng = random.Random() if rng is None else rng + + def perturb(self, data): + shift_ms = self._rng.uniform(self._min_shift_ms, self._max_shift_ms) + if abs(shift_ms) / 1000 > data.duration: + # TODO: do something smarter than just ignore this condition + return + shift_samples = int(shift_ms * data.sample_rate // 1000) + # logging.debug("shift: %s", shift_samples) + if shift_samples < 0: + data._samples[-shift_samples:] = data._samples[:shift_samples] + data._samples[:-shift_samples] = 0 + elif shift_samples > 0: + data._samples[:-shift_samples] = data._samples[shift_samples:] + data._samples[-shift_samples:] = 0 + + +class NoisePerturbation(Perturbation): + """ + Perturbation that adds noise to input audio. + + Args: + manifest_path (str): Manifest file with paths to noise files + min_snr_db (float): Minimum SNR of audio after noise is added + max_snr_db (float): Maximum SNR of audio after noise is added + max_gain_db (float): Maximum gain that can be applied on the noise sample + audio_tar_filepaths (list) : Tar files, if noise audio files are tarred + shuffle_n (int): Shuffle parameter for shuffling buffered files from the tar files + orig_sr (int): Original sampling rate of the noise files + rng: Random number generator + """ + + def __init__( + self, + manifest_path=None, + min_snr_db=10, + max_snr_db=50, + max_gain_db=300.0, + rng=None, + audio_tar_filepaths=None, + shuffle_n=100, + orig_sr=16000, + ): + self._manifest = collections.ASRAudioText(manifest_path, parser=parsers.make_parser([]), index_by_file_id=True) + self._audiodataset = None + self._tarred_audio = False + self._orig_sr = orig_sr + self._data_iterator = None + + if audio_tar_filepaths: + self._tarred_audio = True + self._audiodataset = AugmentationDataset(manifest_path, audio_tar_filepaths, shuffle_n) + self._data_iterator = iter(self._audiodataset) + + self._rng = random.Random() if rng is None else rng + self._min_snr_db = min_snr_db + self._max_snr_db = max_snr_db + self._max_gain_db = max_gain_db + + @property + def orig_sr(self): + return self._orig_sr + + def get_one_noise_sample(self, target_sr): + return read_one_audiosegment( + self._manifest, target_sr, self._rng, tarred_audio=self._tarred_audio, audio_dataset=self._data_iterator + ) + + def perturb(self, data): + noise = read_one_audiosegment( + self._manifest, + data.sample_rate, + self._rng, + tarred_audio=self._tarred_audio, + audio_dataset=self._data_iterator, + ) + self.perturb_with_input_noise(data, noise) + + def perturb_with_input_noise(self, data, noise, data_rms=None): + snr_db = self._rng.uniform(self._min_snr_db, self._max_snr_db) + if data_rms is None: + data_rms = data.rms_db + noise_gain_db = min(data_rms - noise.rms_db - snr_db, self._max_gain_db) + # logging.debug("noise: %s %s %s", snr_db, noise_gain_db, noise_record.audio_file) + + # calculate noise segment to use + start_time = self._rng.uniform(0.0, noise.duration - data.duration) + if noise.duration > (start_time + data.duration): + noise.subsegment(start_time=start_time, end_time=start_time + data.duration) + + # adjust gain for snr purposes and superimpose + noise.gain_db(noise_gain_db) + + if noise._samples.shape[0] < data._samples.shape[0]: + noise_idx = self._rng.randint(0, data._samples.shape[0] - noise._samples.shape[0]) + data._samples[noise_idx : noise_idx + noise._samples.shape[0]] += noise._samples + + else: + data._samples += noise._samples + + def perturb_with_foreground_noise( + self, data, noise, data_rms=None, max_noise_dur=2, max_additions=1, + ): + snr_db = self._rng.uniform(self._min_snr_db, self._max_snr_db) + if not data_rms: + data_rms = data.rms_db + + noise_gain_db = min(data_rms - noise.rms_db - snr_db, self._max_gain_db) + n_additions = self._rng.randint(1, max_additions) + + for i in range(n_additions): + noise_dur = self._rng.uniform(0.0, max_noise_dur) + start_time = self._rng.uniform(0.0, noise.duration) + start_sample = int(round(start_time * noise.sample_rate)) + end_sample = int(round(min(noise.duration, (start_time + noise_dur)) * noise.sample_rate)) + noise_samples = np.copy(noise._samples[start_sample:end_sample]) + # adjust gain for snr purposes and superimpose + noise_samples *= 10.0 ** (noise_gain_db / 20.0) + + if noise_samples.shape[0] > data._samples.shape[0]: + noise_samples = noise_samples[0 : data._samples.shape[0]] + + noise_idx = self._rng.randint(0, data._samples.shape[0] - noise_samples.shape[0]) + data._samples[noise_idx : noise_idx + noise_samples.shape[0]] += noise_samples + + +class RandomNoisePerturbation(Perturbation): + """ + Perturbation that adds noise to input audio. + + Args: + manifest_path (str): Manifest file with paths to noise files + min_snr_db (float): Minimum SNR of audio after noise is added + max_snr_db (float): Maximum SNR of audio after noise is added + max_gain_db (float): Maximum gain that can be applied on the noise sample + audio_tar_filepaths (list) : Tar files, if noise audio files are tarred + shuffle_n (int): Shuffle parameter for shuffling buffered files from the tar files + orig_sr (int): Original sampling rate of the noise files + rng: Random number generator + """ + + def __init__( + self, + manifest_path=None, + min_snr_db=10, + max_snr_db=50, + max_gain_db=300.0, + ratio=1.0, + rng=None, + target_sr=16000, + data_dir='', + cache_noise=False + ): + + self._rng = random.Random() if rng is None else rng + self._min_snr_db = min_snr_db + self._max_snr_db = max_snr_db + self._max_gain_db = max_gain_db + self.ratio = ratio + + self.data_dir = data_dir + self.target_sr = target_sr + manifest, self._noise_weights = self.read_manifest(manifest_path) + self.noise_files = manifest['wav_filename'].tolist() + self._cache = {} + self.cache_noise = cache_noise + + def read_manifest(self, manifest_fps): + manifest_files = [] + for fp in manifest_fps: + manifest_files.append(pandas.read_csv(fp, encoding='utf-8')) + manifest = pandas.concat(manifest_files) + + orig_noise_num = len(manifest) + wav_header_size = 44 + # only use noise with duration longer than 1 second + manifest = manifest[manifest['wav_filesize'] > (1 * 16000 * 2 + wav_header_size)] + print('filter noise less than 1s: from {} to {} samples'.format(orig_noise_num, len(manifest))) + wav_data_size = manifest['wav_filesize'].values - wav_header_size + print('noise duration sum: {}h'.format(wav_data_size.sum() / (16000 * 2) / 60 / 60)) + noise_weights = wav_data_size / wav_data_size.sum() + return manifest, noise_weights.tolist() + + def perturb(self, data): + if self._rng.random() < self.ratio: + noises = self.get_noises(data.num_samples) + self.perturb_with_input_noise(data, noises) + + def perturb_with_input_noise(self, data, noises): + snr_db = self._rng.uniform(self._min_snr_db, self._max_snr_db) + data_rms = rms_db(data._samples) + + start_index = 0 + for noise_i in noises: + noise_gain_db = min(data_rms - rms_db(noise_i) - snr_db, self._max_gain_db) + # logging.debug("noise: %s %s %s", snr_db, noise_gain_db, noise_record.audio_file) + + # adjust gain for snr purposes and superimpose + noise_i = gain_db(noise_i, noise_gain_db) + + end_index = start_index + noise_i.shape[0] + data._samples[start_index:end_index] += noise_i + start_index = end_index + assert end_index == data.num_samples + + def get_noises(self, num_samples): + left_noise_samples = num_samples + noise_data = [] + while left_noise_samples > 0: + noise = self.read_one_noise() + if noise.shape[0] > left_noise_samples: + start_pos = self._rng.randrange(0, noise.shape[0] - left_noise_samples + 1) + noise = noise[start_pos:start_pos + left_noise_samples] + left_noise_samples -= noise.shape[0] + noise_data.append(noise) + assert left_noise_samples == 0 + return noise_data + + def read_one_noise(self): + fp = self._rng.choices(self.noise_files, weights=self._noise_weights)[0] + fp = os.path.join(self.data_dir, fp) + + cached_noise = self._cache.get(fp, None) + if cached_noise is None: + cached_noise = AudioSegment.from_file(fp, target_sr=self.target_sr)._samples + if self.cache_noise: + self._cache[fp] = cached_noise + return cached_noise.copy() + + +def rms_db(samples): + mean_square = np.mean(samples ** 2) + if mean_square == 0: + return -np.inf + else: + return 10 * np.log10(mean_square) + + +def gain_db(samples, gain): + return samples * (10.0 ** (gain / 20.0)) + + +class WhiteNoisePerturbation(Perturbation): + """ + Perturbation that adds white noise to an audio file in the training dataset. + + Args: + min_level (int): Minimum level in dB at which white noise should be added + max_level (int): Maximum level in dB at which white noise should be added + rng: Random number generator + """ + + def __init__(self, min_level=-90, max_level=-46, rng=None): + self.min_level = int(min_level) + self.max_level = int(max_level) + self._rng = np.random.RandomState() if rng is None else rng + + def perturb(self, data): + noise_level_db = self._rng.randint(self.min_level, self.max_level, dtype='int32') + noise_signal = self._rng.randn(data._samples.shape[0]) * (10.0 ** (noise_level_db / 20.0)) + data._samples += noise_signal + + +class RirAndNoisePerturbation(Perturbation): + def __init__( + self, + rir_manifest_path=None, + rir_prob=0.5, + noise_manifest_paths=None, + min_snr_db=0, + max_snr_db=50, + rir_tar_filepaths=None, + rir_shuffle_n=100, + noise_tar_filepaths=None, + apply_noise_rir=False, + orig_sample_rate=None, + max_additions=5, + max_duration=2.0, + bg_noise_manifest_paths=None, + bg_min_snr_db=10, + bg_max_snr_db=50, + bg_noise_tar_filepaths=None, + bg_orig_sample_rate=None, + ): + """ + RIR augmentation with additive foreground and background noise. + In this implementation audio data is augmented by first convolving the audio with a Room Impulse Response + and then adding foreground noise and background noise at various SNRs. RIR, foreground and background noises + should either be supplied with a manifest file or as tarred audio files (faster). + + Different sets of noise audio files based on the original sampling rate of the noise. This is useful while + training a mixed sample rate model. For example, when training a mixed model with 8 kHz and 16 kHz audio with a + target sampling rate of 16 kHz, one would want to augment 8 kHz data with 8 kHz noise rather than 16 kHz noise. + + Args: + rir_manifest_path: manifest file for RIRs + rir_tar_filepaths: tar files, if RIR audio files are tarred + rir_prob: probability of applying a RIR + noise_manifest_paths: foreground noise manifest path + min_snr_db: min SNR for foreground noise + max_snr_db: max SNR for background noise, + noise_tar_filepaths: tar files, if noise files are tarred + apply_noise_rir: whether to convolve foreground noise with a a random RIR + orig_sample_rate: original sampling rate of foreground noise audio + max_additions: max number of times foreground noise is added to an utterance, + max_duration: max duration of foreground noise + bg_noise_manifest_paths: background noise manifest path + bg_min_snr_db: min SNR for background noise + bg_max_snr_db: max SNR for background noise + bg_noise_tar_filepaths: tar files, if noise files are tarred + bg_orig_sample_rate: original sampling rate of background noise audio + + """ + logging.info("Called Rir aug init") + self._rir_prob = rir_prob + self._rng = random.Random() + self._rir_perturber = ImpulsePerturbation( + manifest_path=rir_manifest_path, + audio_tar_filepaths=rir_tar_filepaths, + shuffle_n=rir_shuffle_n, + shift_impulse=True, + ) + self._fg_noise_perturbers = {} + self._bg_noise_perturbers = {} + if noise_manifest_paths: + for i in range(len(noise_manifest_paths)): + if orig_sample_rate is None: + orig_sr = 16000 + else: + orig_sr = orig_sample_rate[i] + self._fg_noise_perturbers[orig_sr] = NoisePerturbation( + manifest_path=noise_manifest_paths[i], + min_snr_db=min_snr_db[i], + max_snr_db=max_snr_db[i], + audio_tar_filepaths=noise_tar_filepaths[i], + orig_sr=orig_sr, + ) + self._max_additions = max_additions + self._max_duration = max_duration + if bg_noise_manifest_paths: + for i in range(len(bg_noise_manifest_paths)): + if bg_orig_sample_rate is None: + orig_sr = 16000 + else: + orig_sr = bg_orig_sample_rate[i] + self._bg_noise_perturbers[orig_sr] = NoisePerturbation( + manifest_path=bg_noise_manifest_paths[i], + min_snr_db=bg_min_snr_db[i], + max_snr_db=bg_max_snr_db[i], + audio_tar_filepaths=bg_noise_tar_filepaths[i], + orig_sr=orig_sr, + ) + + self._apply_noise_rir = apply_noise_rir + + def perturb(self, data): + prob = self._rng.uniform(0.0, 1.0) + + if prob < self._rir_prob: + self._rir_perturber.perturb(data) + + orig_sr = data.orig_sr + if orig_sr not in self._fg_noise_perturbers: + orig_sr = max(self._fg_noise_perturbers.keys()) + fg_perturber = self._fg_noise_perturbers[orig_sr] + + orig_sr = data.orig_sr + if orig_sr not in self._bg_noise_perturbers: + orig_sr = max(self._bg_noise_perturbers.keys()) + bg_perturber = self._bg_noise_perturbers[orig_sr] + + data_rms = data.rms_db + noise = fg_perturber.get_one_noise_sample(data.sample_rate) + if self._apply_noise_rir: + self._rir_perturber.perturb(noise) + fg_perturber.perturb_with_foreground_noise( + data, noise, data_rms=data_rms, max_noise_dur=self._max_duration, max_additions=self._max_additions + ) + noise = bg_perturber.get_one_noise_sample(data.sample_rate) + bg_perturber.perturb_with_input_noise(data, noise, data_rms=data_rms) + + +class TranscodePerturbation(Perturbation): + def __init__(self, rng=None): + """ + Audio codec augmentation. This implementation uses sox to transcode audio with low rate audio codecs, + so users need to make sure that the installed sox version supports the codecs used here (G711 and amr-nb). + + """ + self._rng = np.random.RandomState() if rng is None else rng + self._codecs = ["g711", "amr-nb"] + + def perturb(self, data): + att_factor = 0.8 + max_level = np.max(np.abs(data._samples)) + norm_factor = att_factor / max_level + norm_samples = norm_factor * data._samples + orig_f = NamedTemporaryFile(suffix=".wav") + sf.write(orig_f.name, norm_samples.transpose(), 16000) + + codec_ind = random.randint(0, len(self._codecs) - 1) + if self._codecs[codec_ind] == "amr-nb": + transcoded_f = NamedTemporaryFile(suffix="_amr.wav") + rates = list(range(0, 8)) + rate = rates[random.randint(0, len(rates) - 1)] + _ = subprocess.check_output( + f"sox {orig_f.name} -V0 -C {rate} -t amr-nb - | sox -t amr-nb - -V0 -b 16 -r 16000 {transcoded_f.name}", + shell=True, + ) + elif self._codecs[codec_ind] == "g711": + transcoded_f = NamedTemporaryFile(suffix="_g711.wav") + _ = subprocess.check_output( + f"sox {orig_f.name} -V0 -r 8000 -c 1 -e a-law {transcoded_f.name}", shell=True + ) + + new_data = AudioSegment.from_file(transcoded_f.name, target_sr=16000) + data._samples = new_data._samples[0 : data._samples.shape[0]] + return + + +perturbation_types = { + "speed": SpeedPerturbation, + "time_stretch": TimeStretchPerturbation, + "tempo": TempoPerturbation, + "gain": GainPerturbation, + "impulse": ImpulsePerturbation, + "shift": ShiftPerturbation, + "noise": NoisePerturbation, + "white_noise": WhiteNoisePerturbation, + "rir_noise_aug": RirAndNoisePerturbation, + "transcode_aug": TranscodePerturbation, +} + + +def register_perturbation(name: str, perturbation: Perturbation): + if name in perturbation_types.keys(): + raise KeyError( + f"Perturbation with the name {name} exists. " f"Type of perturbation : {perturbation_types[name]}." + ) + + perturbation_types[name] = perturbation + + +class AudioAugmentor(object): + def __init__(self, perturbations=None, rng=None): + self._rng = random.Random() if rng is None else rng + self._pipeline = perturbations if perturbations is not None else [] + if perturbations: + print('audio perturbations:', perturbations) + + def perturb(self, segment): + for (prob, p) in self._pipeline: + if self._rng.random() < prob: + p.perturb(segment) + return + + def max_augmentation_length(self, length): + newlen = length + for (prob, p) in self._pipeline: + newlen = p.max_augmentation_length(newlen) + return newlen + + @classmethod + def from_config(cls, config): + ptbs = [] + for p in config: + if p['aug_type'] not in perturbation_types: + logging.warning("%s perturbation not known. Skipping.", p['aug_type']) + continue + perturbation = perturbation_types[p['aug_type']] + ptbs.append((p['prob'], perturbation(**p['cfg']))) + return cls(perturbations=ptbs) + + +def process_augmentations(augmenter) -> Optional[AudioAugmentor]: + """Process list of online data augmentations. + Accepts either an AudioAugmentor object with pre-defined augmentations, + or a dictionary that points to augmentations that have been defined. + If a dictionary is passed, must follow the below structure: + Dict[str, Dict[str, Any]]: Which refers to a dictionary of string + names for augmentations, defined in `asr/parts/perturb.py`. + The inner dictionary may contain key-value arguments of the specific + augmentation, along with an essential key `prob`. `prob` declares the + probability of the augmentation being applied, and must be a float + value in the range [0, 1]. + # Example in YAML config file + Augmentations are generally applied only during training, so we can add + these augmentations to our yaml config file, and modify the behaviour + for training and evaluation. + ```yaml + AudioToSpeechLabelDataLayer: + ... # Parameters shared between train and evaluation time + train: + augmentor: + shift: + prob: 0.5 + min_shift_ms: -5.0 + max_shift_ms: 5.0 + white_noise: + prob: 1.0 + min_level: -90 + max_level: -46 + ... + eval: + ... + ``` + Then in the training script, + ```python + import copy + from ruamel.yaml import YAML + yaml = YAML(typ="safe") + with open(model_config) as f: + params = yaml.load(f) + # Train Config for Data Loader + train_dl_params = copy.deepcopy(params["AudioToTextDataLayer"]) + train_dl_params.update(params["AudioToTextDataLayer"]["train"]) + del train_dl_params["train"] + del train_dl_params["eval"] + data_layer_train = nemo_asr.AudioToTextDataLayer( + ..., + **train_dl_params, + ) + # Evaluation Config for Data Loader + eval_dl_params = copy.deepcopy(params["AudioToTextDataLayer"]) + eval_dl_params.update(params["AudioToTextDataLayer"]["eval"]) + del eval_dl_params["train"] + del eval_dl_params["eval"] + data_layer_eval = nemo_asr.AudioToTextDataLayer( + ..., + **eval_dl_params, + ) + ``` + # Registering your own Augmentations + To register custom augmentations to obtain the above convenience of + the declaring the augmentations in YAML, you can put additional keys in + `perturbation_types` dictionary as follows. + ```python + from nemo.collections.asr.parts import perturb + # Define your own perturbation here + class CustomPerturbation(perturb.Perturbation): + ... + perturb.register_perturbation(name_of_perturbation, CustomPerturbation) + ``` + Args: + augmenter: AudioAugmentor object or + dictionary of str -> kwargs (dict) which is parsed and used + to initialize an AudioAugmentor. + Note: It is crucial that each individual augmentation has + a keyword `prob`, that defines a float probability in the + the range [0, 1] of this augmentation being applied. + If this keyword is not present, then the augmentation is + disabled and a warning is logged. + Returns: AudioAugmentor object + """ + if augmenter is None: + return None + + if isinstance(augmenter, AudioAugmentor): + return augmenter + + if not type(augmenter) in {dict, DictConfig}: + raise ValueError("Cannot parse augmenter. Must be a dict or an AudioAugmentor object ") + + if isinstance(augmenter, DictConfig): + augmenter = OmegaConf.to_container(augmenter, resolve=True) + + augmenter = copy.deepcopy(augmenter) + + augmentations = [] + for augment_name, augment_kwargs in augmenter.items(): + prob = augment_kwargs.get('prob', None) + + if prob is None: + raise KeyError( + f'Augmentation "{augment_name}" will not be applied as ' + f'keyword argument "prob" was not defined for this augmentation.' + ) + + else: + _ = augment_kwargs.pop('prob') + + if prob < 0.0 or prob > 1.0: + raise ValueError("`prob` must be a float value between 0 and 1.") + + try: + augmentation = perturbation_types[augment_name](**augment_kwargs) + augmentations.append([prob, augmentation]) + except KeyError: + raise KeyError(f"Invalid perturbation name. Allowed values : {perturbation_types.keys()}") + + augmenter = AudioAugmentor(perturbations=augmentations) + return augmenter + + +class AugmentationDataset(IterableDataset): + """ + A class that loads tarred audio files and cycles over the files in the dataset. + + Accepts a single comma-separated JSON manifest file (in the same style as for the AudioToCharDataset/AudioToBPEDataset), + as well as the path(s) to the tarball(s) containing the wav files. Each line of the manifest should + contain the information for one audio file, including at least the transcript and name of the audio + file within the tarball. + + Valid formats for the audio_tar_filepaths argument include: + (1) a single string that can be brace-expanded, e.g. 'path/to/audio.tar' or 'path/to/audio_{1..100}.tar.gz', or + (2) a list of file paths that will not be brace-expanded, e.g. ['audio_1.tar', 'audio_2.tar', ...]. + + Note: For brace expansion in (1), there may be cases where `{x..y}` syntax cannot be used due to shell interference. + This occurs most commonly inside SLURM scripts. Therefore we provide a few equivalent replacements. + Supported opening braces - { <=> (, [, < and the special tag _OP_. + Supported closing braces - } <=> ), ], > and the special tag _CL_. + For SLURM based tasks, we suggest the use of the special tags for ease of use. + + See the WebDataset documentation for more information about accepted data and input formats. + """ + + def __init__(self, manifest_path: str, tar_filepaths: Union[str, List[str]], shuffle_n: int = 128): + self._manifest = collections.ASRAudioText(manifest_path, parser=parsers.make_parser([]), index_by_file_id=True) + + if isinstance(tar_filepaths, str): + # Replace '(' and '[' with '{' + brace_keys_open = ['(', '[', '<', '_OP_'] + for bkey in brace_keys_open: + if bkey in tar_filepaths: + tar_filepaths = tar_filepaths.replace(bkey, "{") + + # Replace ')' and ']' with '}' + brace_keys_close = [')', ']', '>', '_CL_'] + for bkey in brace_keys_close: + if bkey in tar_filepaths: + tar_filepaths = tar_filepaths.replace(bkey, "}") + + self.audio_dataset = ( + wd.Dataset(tar_filepaths).shuffle(shuffle_n).rename(audio='wav', key='__key__').to_tuple('audio', 'key') + ) + self.audio_iter = iter(self.audio_dataset) + + def __len__(self): + return len(self._manifest) + + def __iter__(self): + return self + + def __next__(self): + while True: + try: + audio_bytes, audio_filename = next(self.audio_iter) + + except StopIteration: + self.audio_iter = iter(self.audio_dataset) + audio_bytes, audio_filename = next(self.audio_iter) + file_id, _ = os.path.splitext(os.path.basename(audio_filename)) + + # Convert audio bytes to IO stream for processing (for SoundFile to read) + audio_file = io.BytesIO(audio_bytes) + return audio_file, file_id diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/rnnt_beam_decoding.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/rnnt_beam_decoding.py new file mode 100644 index 0000000000000000000000000000000000000000..29c9960ffe14fd99bff9e9cd061c5aa4b1c3511e --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/rnnt_beam_decoding.py @@ -0,0 +1,680 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright 2017 Johns Hopkins University (Shinji Watanabe) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional, Union + +import numpy as np +import torch +from tqdm import tqdm + +from nemo.collections.asr.modules import rnnt_abstract +from nemo.collections.asr.parts import rnnt_utils +from nemo.collections.asr.parts.rnnt_utils import Hypothesis, NBestHypotheses +from nemo.core.classes import Typing, typecheck +from nemo.core.neural_types import AcousticEncodedRepresentation, HypothesisType, LengthsType, NeuralType + + +class BeamRNNTInfer(Typing): + @property + def input_types(self): + """Returns definitions of module input ports. + """ + return { + "encoder_output": NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation()), + "encoded_lengths": NeuralType(tuple('B'), LengthsType()), + } + + @property + def output_types(self): + """Returns definitions of module output ports. + """ + return {"predictions": NeuralType(elements_type=HypothesisType())} + + def __init__( + self, + decoder_model: rnnt_abstract.AbstractRNNTDecoder, + joint_model: rnnt_abstract.AbstractRNNTJoint, + beam_size: int, + search_type: str = 'default', + score_norm: bool = True, + return_best_hypothesis: bool = True, + tsd_max_sym_exp_per_step: Optional[int] = 50, + alsd_max_target_len: Union[int, float] = 1.0, + nsc_max_timesteps_expansion: int = 1, + nsc_prefix_alpha: int = 1, + ): + """ + Beam Search implementation ported from ESPNet implementation - + https://github.com/espnet/espnet/blob/master/espnet/nets/beam_search_transducer.py + + Sequence level beam decoding or batched-beam decoding, performed auto-repressively + depending on the search type chosen. + + Args: + decoder_model: rnnt_utils.AbstractRNNTDecoder implementation. + joint_model: rnnt_utils.AbstractRNNTJoint implementation. + + beam_size: number of beams for beam search. Must be a positive integer >= 1. + If beam size is 1, defaults to stateful greedy search. + This greedy search might result in slightly different results than + the greedy results obtained by GreedyRNNTInfer due to implementation differences. + + For accurate greedy results, please use GreedyRNNTInfer or GreedyBatchedRNNTInfer. + + search_type: str representing the type of beam search to perform. + Must be one of ['beam', 'tsd', 'alsd']. 'nsc' is currently not supported. + + Algoritm used: + `beam` - basic beam search strategy. Larger beams generally result in better decoding, + however the time required for the search also grows steadily. + + `tsd` - time synchronous decoding. Please refer to the paper: + [Alignment-Length Synchronous Decoding for RNN Transducer](https://ieeexplore.ieee.org/document/9053040) + for details on the algorithm implemented. + + Time synchronous decoding (TSD) execution time grows by the factor T * max_symmetric_expansions. + For longer sequences, T is greater, and can therefore take a long time for beams to obtain + good results. This also requires greater memory to execute. + + `alsd` - alignment-length synchronous decoding. Please refer to the paper: + [Alignment-Length Synchronous Decoding for RNN Transducer](https://ieeexplore.ieee.org/document/9053040) + for details on the algorithm implemented. + + Alignment-length synchronous decoding (ALSD) execution time is faster than TSD, with growth + factor of T + U_max, where U_max is the maximum target length expected during execution. + + Generally, T + U_max < T * max_symmetric_expansions. However, ALSD beams are non-unique, + therefore it is required to use larger beam sizes to achieve the same (or close to the same) + decoding accuracy as TSD. + + For a given decoding accuracy, it is possible to attain faster decoding via ALSD than TSD. + + score_norm: bool, whether to normalize the scores of the log probabilities. + + return_best_hypothesis: bool, decides whether to return a single hypothesis (the best out of N), + or return all N hypothesis (sorted with best score first). The container class changes based + this flag - + When set to True (default), returns a single Hypothesis. + When set to False, returns a NBestHypotheses container, which contains a list of Hypothesis. + + # The following arguments are specific to the chosen `search_type` + + tsd_max_sym_exp_per_step: Used for `search_type=tsd`. The maximum symmetric expansions allowed + per timestep during beam search. Larger values should be used to attempt decoding of longer + sequences, but this in turn increases execution time and memory usage. + + alsd_max_target_len: Used for `search_type=alsd`. The maximum expected target sequence length + during beam search. Larger values allow decoding of longer sequences at the expense of + execution time and memory. + + # The following two flags are placeholders and unused until `nsc` implementation is stabilized. + + nsc_max_timesteps_expansion: Unused int. + + nsc_prefix_alpha: Unused int. + """ + self.decoder = decoder_model + self.joint = joint_model + + self.blank = decoder_model.blank_idx + self.vocab_size = decoder_model.vocab_size + self.search_type = search_type + self.return_best_hypothesis = return_best_hypothesis + + if beam_size < 1: + raise ValueError("Beam search size cannot be less than 1!") + + self.beam_size = beam_size + self.score_norm = score_norm + + if self.beam_size == 1: + self.search_algorithm = self.greedy_search + elif search_type == "default": + self.search_algorithm = self.default_beam_search + elif search_type == "tsd": + self.search_algorithm = self.time_sync_decoding + elif search_type == "alsd": + self.search_algorithm = self.align_length_sync_decoding + elif search_type == "nsc": + raise NotImplementedError("`nsc` (Constrained Beam Search) has not been implemented.") + # self.search_algorithm = self.nsc_beam_search + else: + raise NotImplementedError( + f"The search type ({search_type}) supplied is not supported!\n" + f"Please use one of : (default, tsd, alsd, nsc)" + ) + + if tsd_max_sym_exp_per_step is None: + tsd_max_sym_exp_per_step = -1 + + if search_type in ['tsd', 'alsd', 'nsc'] and not self.decoder.blank_as_pad: + raise ValueError( + f"Search type was chosen as '{search_type}', however the decoder module provided " + f"does not support the `blank` token as a pad value. {search_type} requires " + f"the blank token as pad value support in order to perform batched beam search." + f"Please chose one of the other beam search methods, or re-train your model " + f"with this support." + ) + + self.tsd_max_symmetric_expansion_per_step = tsd_max_sym_exp_per_step + self.alsd_max_target_length = alsd_max_target_len + self.nsc_max_timesteps_expansion = nsc_max_timesteps_expansion + self.nsc_prefix_alpha = nsc_prefix_alpha + + @typecheck() + def __call__( + self, encoder_output: torch.Tensor, encoded_lengths: torch.Tensor + ) -> Union[Hypothesis, NBestHypotheses]: + """Perform general beam search. + + Args: + encoder_output: Encoded speech features (B, T_max, D_enc) + encoded_lengths: Lengths of the encoder outputs + + Returns: + Either a list containing a single Hypothesis (when `return_best_hypothesis=True`, + otherwise a list containing a single NBestHypotheses, which itself contains a list of + Hypothesis. This list is sorted such that the best hypothesis is the first element. + """ + # Preserve decoder and joint training state + decoder_training_state = self.decoder.training + joint_training_state = self.joint.training + + with torch.no_grad(): + # Apply optional preprocessing + encoder_output = encoder_output.transpose(1, 2) # (B, T, D) + + self.decoder.eval() + self.joint.eval() + + hypotheses = [] + with tqdm( + range(encoder_output.size(0)), + desc='Beam search progress:', + total=encoder_output.size(0), + unit='sample', + ) as idx_gen: + + # Freeze the decoder and joint to prevent recording of gradients + # during the beam loop. + with self.decoder.as_frozen(), self.joint.as_frozen(): + + # Decode every sample in the batch independently. + for batch_idx in idx_gen: + inseq = encoder_output[batch_idx : batch_idx + 1, :, :] # [1, T, D] + logitlen = encoded_lengths[batch_idx] + + # Execute the specific search strategy + nbest_hyps = self.search_algorithm(inseq, logitlen) # sorted list of hypothesis + + # Pack the result + if self.return_best_hypothesis: + best_hypothesis = nbest_hyps[0] # type: Hypothesis + else: + best_hypothesis = NBestHypotheses(nbest_hyps) # type: NBestHypotheses + hypotheses.append(best_hypothesis) + + self.decoder.train(decoder_training_state) + self.joint.train(joint_training_state) + + return (hypotheses,) + + def sort_nbest(self, hyps: List[Hypothesis]) -> List[Hypothesis]: + """Sort hypotheses by score or score given sequence length. + + Args: + hyps: list of hypotheses + + Return: + hyps: sorted list of hypotheses + """ + if self.score_norm: + return sorted(hyps, key=lambda x: x.score / len(x.y_sequence), reverse=True) + else: + return sorted(hyps, key=lambda x: x.score, reverse=True) + + def greedy_search(self, h: torch.Tensor, encoded_lengths: torch.Tensor) -> List[Hypothesis]: + """Greedy search implementation for transducer. + Generic case when beam size = 1. Results might differ slightly due to implementation details + as compared to `GreedyRNNTInfer` and `GreedyBatchRNNTInfer`. + + Args: + h: Encoded speech features (1, T_max, D_enc) + + Returns: + hyp: 1-best decoding results + """ + # Initialize zero state vectors + dec_state = self.decoder.initialize_state(h) + + # Construct initial hypothesis + hyp = Hypothesis(score=0.0, y_sequence=[self.blank], dec_state=dec_state) + cache = {} + + # Initialize state and first token + y, state, _ = self.decoder.score_hypothesis(hyp, cache) + + for i in range(int(encoded_lengths)): + hi = h[:, i : i + 1, :] # [1, 1, D] + + not_blank = True + symbols_added = 0 + + while not_blank: + ytu = torch.log_softmax(self.joint.joint(hi, y), dim=-1) # [1, 1, 1, V + 1] + ytu = ytu[0, 0, 0, :] # [V + 1] + + # max() requires float + if ytu.dtype != torch.float32: + ytu = ytu.float() + + logp, pred = torch.max(ytu, dim=-1) # [1, 1] + pred = pred.item() + + if pred == self.blank: + not_blank = False + else: + # Update state and current sequence + hyp.y_sequence.append(int(pred)) + hyp.score += float(logp) + hyp.dec_state = state + + # Compute next state and token + y, state, _ = self.decoder.score_hypothesis(hyp, cache) + symbols_added += 1 + + return [hyp] + + def default_beam_search(self, h: torch.Tensor, encoded_lengths: torch.Tensor) -> List[Hypothesis]: + """Beam search implementation. + + Args: + x: Encoded speech features (1, T_max, D_enc) + + Returns: + nbest_hyps: N-best decoding results + """ + # Initialize states + beam = min(self.beam_size, self.vocab_size) + beam_k = min(beam, (self.vocab_size - 1)) + blank_tensor = torch.tensor([self.blank], device=h.device, dtype=torch.long) + + # Precompute some constants for blank position + ids = list(range(self.vocab_size + 1)) + ids.remove(self.blank) + + # Used when blank token is first vs last token + if self.blank == 0: + index_incr = 1 + else: + index_incr = 0 + + # Initialize zero vector states + dec_state = self.decoder.initialize_state(h) + + # Initialize first hypothesis for the beam (blank) + kept_hyps = [Hypothesis(score=0.0, y_sequence=[self.blank], dec_state=dec_state)] + cache = {} + + for i in range(int(encoded_lengths)): + hi = h[:, i : i + 1, :] # [1, 1, D] + hyps = kept_hyps + kept_hyps = [] + + while True: + max_hyp = max(hyps, key=lambda x: x.score) + hyps.remove(max_hyp) + + # update decoder state and get next score + y, state, lm_tokens = self.decoder.score_hypothesis(max_hyp, cache) # [1, 1, D] + + # get next token + ytu = torch.log_softmax(self.joint.joint(hi, y), dim=-1) # [1, 1, 1, V + 1] + ytu = ytu[0, 0, 0, :] # [V + 1] + + # remove blank token before top k + top_k = ytu[ids].topk(beam_k, dim=-1) + + # Two possible steps - blank token or non-blank token predicted + ytu = ( + torch.cat((top_k[0], ytu[self.blank].unsqueeze(0))), + torch.cat((top_k[1] + index_incr, blank_tensor)), + ) + + # for each possible step + for logp, k in zip(*ytu): + # construct hypothesis for step + new_hyp = Hypothesis( + score=(max_hyp.score + float(logp)), + y_sequence=max_hyp.y_sequence[:], + dec_state=max_hyp.dec_state, + lm_state=max_hyp.lm_state, + ) + + # if current token is blank, dont update sequence, just store the current hypothesis + if k == self.blank: + kept_hyps.append(new_hyp) + else: + # if non-blank token was predicted, update state and sequence and then search more hypothesis + new_hyp.dec_state = state + new_hyp.y_sequence.append(int(k)) + + hyps.append(new_hyp) + + # keep those hypothesis that have scores greater than next search generation + hyps_max = float(max(hyps, key=lambda x: x.score).score) + kept_most_prob = sorted([hyp for hyp in kept_hyps if hyp.score > hyps_max], key=lambda x: x.score,) + + # If enough hypothesis have scores greater than next search generation, + # stop beam search. + if len(kept_most_prob) >= beam: + kept_hyps = kept_most_prob + break + + return self.sort_nbest(kept_hyps) + + def time_sync_decoding(self, h: torch.Tensor, encoded_lengths: torch.Tensor) -> List[Hypothesis]: + """Time synchronous beam search implementation. + Based on https://ieeexplore.ieee.org/document/9053040 + + Args: + h: Encoded speech features (1, T_max, D_enc) + + Returns: + nbest_hyps: N-best decoding results + """ + # Precompute some constants for blank position + ids = list(range(self.vocab_size + 1)) + ids.remove(self.blank) + + # Used when blank token is first vs last token + if self.blank == 0: + index_incr = 1 + else: + index_incr = 0 + + # prepare the batched beam states + beam = min(self.beam_size, self.vocab_size) + beam_state = self.decoder.initialize_state( + torch.zeros(beam, device=h.device, dtype=h.dtype) + ) # [L, B, H], [L, B, H] (for LSTMs) + + # Initialize first hypothesis for the beam (blank) + B = [Hypothesis(y_sequence=[self.blank], score=0.0, dec_state=self.decoder.batch_select_state(beam_state, 0))] + cache = {} + + for i in range(int(encoded_lengths)): + hi = h[:, i : i + 1, :] + + # Update caches + A = [] + C = B + + h_enc = hi + + # For a limited number of symmetric expansions per timestep "i" + for v in range(self.tsd_max_symmetric_expansion_per_step): + D = [] + + # Decode a batch of beam states and scores + beam_y, beam_state, beam_lm_tokens = self.decoder.batch_score_hypothesis(C, cache, beam_state) + + # Extract the log probabilities and the predicted tokens + beam_logp = torch.log_softmax(self.joint.joint(h_enc, beam_y), dim=-1) # [B, 1, 1, V + 1] + beam_logp = beam_logp[:, 0, 0, :] # [B, V + 1] + beam_topk = beam_logp[:, ids].topk(beam, dim=-1) + + seq_A = [h.y_sequence for h in A] + + for j, hyp in enumerate(C): + # create a new hypothesis in A + if hyp.y_sequence not in seq_A: + # If the sequence is not in seq_A, add it as the blank token + # In this step, we dont add a token but simply update score + A.append( + Hypothesis( + score=(hyp.score + float(beam_logp[j, self.blank])), + y_sequence=hyp.y_sequence[:], + dec_state=hyp.dec_state, + lm_state=hyp.lm_state, + ) + ) + else: + # merge the existing blank hypothesis score with current score. + dict_pos = seq_A.index(hyp.y_sequence) + + A[dict_pos].score = np.logaddexp( + A[dict_pos].score, (hyp.score + float(beam_logp[j, self.blank])) + ) + + if v < self.tsd_max_symmetric_expansion_per_step: + for j, hyp in enumerate(C): + # for each current hypothesis j + # extract the top token score and top token id for the jth hypothesis + for logp, k in zip(beam_topk[0][j], beam_topk[1][j] + index_incr): + # create new hypothesis and store in D + # Note: This loop does *not* include the blank token! + new_hyp = Hypothesis( + score=(hyp.score + float(logp)), + y_sequence=(hyp.y_sequence + [int(k)]), + dec_state=self.decoder.batch_select_state(beam_state, j), + lm_state=hyp.lm_state, + ) + + D.append(new_hyp) + + # Prune beam + C = sorted(D, key=lambda x: x.score, reverse=True)[:beam] + + # Prune beam + B = sorted(A, key=lambda x: x.score, reverse=True)[:beam] + + return self.sort_nbest(B) + + def align_length_sync_decoding(self, h: torch.Tensor, encoded_lengths: torch.Tensor) -> List[Hypothesis]: + """Alignment-length synchronous beam search implementation. + Based on https://ieeexplore.ieee.org/document/9053040 + + Args: + h: Encoded speech features (1, T_max, D_enc) + + Returns: + nbest_hyps: N-best decoding results + """ + # Precompute some constants for blank position + ids = list(range(self.vocab_size + 1)) + ids.remove(self.blank) + + # Used when blank token is first vs last token + if self.blank == 0: + index_incr = 1 + else: + index_incr = 0 + + # prepare the batched beam states + beam = min(self.beam_size, self.vocab_size) + + h = h[0] # [T, D] + h_length = int(encoded_lengths) + beam_state = self.decoder.initialize_state( + torch.zeros(beam, device=h.device, dtype=h.dtype) + ) # [L, B, H], [L, B, H] for LSTMS + + # compute u_max as either a specific static limit, + # or a multiple of current `h_length` dynamically. + if type(self.alsd_max_target_length) == float: + u_max = int(self.alsd_max_target_length * h_length) + else: + u_max = int(self.alsd_max_target_length) + + # Initialize first hypothesis for the beam (blank) + B = [Hypothesis(y_sequence=[self.blank], score=0.0, dec_state=self.decoder.batch_select_state(beam_state, 0))] + + final = [] + cache = {} + + # ALSD runs for T + U_max steps + for i in range(h_length + u_max): + # Update caches + A = [] + B_ = [] + h_states = [] + + # preserve the list of batch indices which are added into the list + # and those which are removed from the list + # This is necessary to perform state updates in the correct batch indices later + batch_ids = list(range(len(B))) # initialize as a list of all batch ids + batch_removal_ids = [] # update with sample ids which are removed + + for bid, hyp in enumerate(B): + u = len(hyp.y_sequence) - 1 + t = i - u + 1 + + if t > (h_length - 1): + batch_removal_ids.append(bid) + continue + + B_.append(hyp) + h_states.append((t, h[t])) + + if B_: + # Compute the subset of batch ids which were *not* removed from the list above + sub_batch_ids = None + if len(B_) != beam: + sub_batch_ids = batch_ids + for id in batch_removal_ids: + # sub_batch_ids contains list of ids *that were not removed* + sub_batch_ids.remove(id) + + # extract the states of the sub batch only. + beam_state_ = [beam_state[state_id][:, sub_batch_ids, :] for state_id in range(len(beam_state))] + else: + # If entire batch was used (none were removed), simply take all the states + beam_state_ = beam_state + + # Decode a batch/sub-batch of beam states and scores + beam_y, beam_state_, beam_lm_tokens = self.decoder.batch_score_hypothesis(B_, cache, beam_state_) + + # If only a subset of batch ids were updated (some were removed) + if sub_batch_ids is not None: + # For each state in the RNN (2 for LSTM) + for state_id in range(len(beam_state)): + # Update the current batch states with the sub-batch states (in the correct indices) + # These indices are specified by sub_batch_ids, the ids of samples which were updated. + beam_state[state_id][:, sub_batch_ids, :] = beam_state_[state_id][...] + else: + # If entire batch was updated, simply update all the states + beam_state = beam_state_ + + # h_states = list of [t, h[t]] + # so h[1] here is a h[t] of shape [D] + # Simply stack all of the h[t] within the sub_batch/batch (T <= beam) + h_enc = torch.stack([h[1] for h in h_states]) # [T=beam, D] + h_enc = h_enc.unsqueeze(1) # [B=beam, T=1, D]; batch over the beams + + # Extract the log probabilities and the predicted tokens + beam_logp = torch.log_softmax(self.joint.joint(h_enc, beam_y), dim=-1) # [B=beam, 1, 1, V + 1] + beam_logp = beam_logp[:, 0, 0, :] # [B=beam, V + 1] + beam_topk = beam_logp[:, ids].topk(beam, dim=-1) + + for j, hyp in enumerate(B_): + # For all updated samples in the batch, add it as the blank token + # In this step, we dont add a token but simply update score + new_hyp = Hypothesis( + score=(hyp.score + float(beam_logp[j, self.blank])), + y_sequence=hyp.y_sequence[:], + dec_state=hyp.dec_state, + lm_state=hyp.lm_state, + ) + + # Add blank prediction to A + A.append(new_hyp) + + # If the prediction "timestep" t has reached the length of the input sequence + # we can add it to the "finished" hypothesis list. + if h_states[j][0] == (h_length - 1): + final.append(new_hyp) + + # Here, we carefully select the indices of the states that we want to preserve + # for the next token (non-blank) update. + if sub_batch_ids is not None: + h_states_idx = sub_batch_ids[j] + else: + h_states_idx = j + + # for each current hypothesis j + # extract the top token score and top token id for the jth hypothesis + for logp, k in zip(beam_topk[0][j], beam_topk[1][j] + index_incr): + # create new hypothesis and store in A + # Note: This loop does *not* include the blank token! + new_hyp = Hypothesis( + score=(hyp.score + float(logp)), + y_sequence=(hyp.y_sequence[:] + [int(k)]), + dec_state=self.decoder.batch_select_state(beam_state, h_states_idx), + lm_state=hyp.lm_state, + ) + + A.append(new_hyp) + + # Prune and recombine same hypothesis + # This may cause next beam to be smaller than max beam size + # Therefore larger beam sizes may be required for better decoding. + B = sorted(A, key=lambda x: x.score, reverse=True)[:beam] + B = self.recombine_hypotheses(B) + + # If B_ is empty list, then we may be able to early exit + elif len(batch_ids) == len(batch_removal_ids): + break + + if final: + return self.sort_nbest(final) + else: + return B + + def recombine_hypotheses(self, hypotheses: List[Hypothesis]) -> List[Hypothesis]: + """Recombine hypotheses with equivalent output sequence. + + Args: + hypotheses (list): list of hypotheses + + Returns: + final (list): list of recombined hypotheses + """ + final = [] + + for hyp in hypotheses: + seq_final = [f.y_sequence for f in final if f.y_sequence] + + if hyp.y_sequence in seq_final: + seq_pos = seq_final.index(hyp.y_sequence) + + final[seq_pos].score = np.logaddexp(final[seq_pos].score, hyp.score) + else: + final.append(hyp) + + return hypotheses diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/rnnt_greedy_decoding.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/rnnt_greedy_decoding.py new file mode 100644 index 0000000000000000000000000000000000000000..7ec12457ea9ff78f8dd7d63aec40427b48365e3b --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/rnnt_greedy_decoding.py @@ -0,0 +1,565 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright 2017 Johns Hopkins University (Shinji Watanabe) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Union + +import torch + +from nemo.collections.asr.modules import rnnt_abstract +from nemo.collections.asr.parts import rnnt_utils +from nemo.collections.common.parts.rnn import label_collate +from nemo.core.classes import Typing, typecheck +from nemo.core.neural_types import AcousticEncodedRepresentation, HypothesisType, LengthsType, NeuralType + + +class _GreedyRNNTInfer(Typing): + """A greedy transducer decoder. + + Provides a common abstraction for sample level and batch level greedy decoding. + + Args: + decoder_model: rnnt_utils.AbstractRNNTDecoder implementation. + joint_model: rnnt_utils.AbstractRNNTJoint implementation. + blank_index: int index of the blank token. Can be 0 or len(vocabulary). + max_symbols_per_step: Optional int. The maximum number of symbols that can be added + to a sequence in a single time step; if set to None then there is + no limit. + """ + + @property + def input_types(self): + """Returns definitions of module input ports. + """ + return { + "encoder_output": NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation()), + "encoded_lengths": NeuralType(tuple('B'), LengthsType()), + } + + @property + def output_types(self): + """Returns definitions of module output ports. + """ + return {"predictions": NeuralType(elements_type=HypothesisType())} + + def __init__( + self, + decoder_model: rnnt_abstract.AbstractRNNTDecoder, + joint_model: rnnt_abstract.AbstractRNNTJoint, + blank_index: int, + max_symbols_per_step: Optional[int] = None, + ): + super().__init__() + self.decoder = decoder_model + self.joint = joint_model + + self._blank_index = blank_index + self._SOS = blank_index # Start of single index + self.max_symbols = max_symbols_per_step + + def __call__(self, *args, **kwargs): + return self.forward(*args, **kwargs) + + @torch.no_grad() + def _pred_step( + self, + label: Union[torch.Tensor, int], + hidden: Optional[torch.Tensor], + add_sos: bool = False, + batch_size: Optional[int] = None, + ) -> (torch.Tensor, torch.Tensor): + """ + Common prediction step based on the AbstractRNNTDecoder implementation. + + Args: + label: (int/torch.Tensor): Label or "Start-of-Signal" token. + hidden: (Optional torch.Tensor): RNN State vector + add_sos (bool): Whether to add a zero vector at the begging as "start of sentence" token. + batch_size: Batch size of the output tensor. + + Returns: + g: (B, U, H) if add_sos is false, else (B, U + 1, H) + hid: (h, c) where h is the final sequence hidden state and c is + the final cell state: + h (tensor), shape (L, B, H) + c (tensor), shape (L, B, H) + """ + if isinstance(label, torch.Tensor): + # label: [batch, 1] + if label.dtype != torch.long: + label = label.long() + + else: + # Label is an integer + if label == self._SOS: + return self.decoder.predict(None, hidden, add_sos=add_sos, batch_size=batch_size) + + label = label_collate([[label]]) + + # output: [B, 1, K] + return self.decoder.predict(label, hidden, add_sos=add_sos, batch_size=batch_size) + + def _joint_step(self, enc, pred, log_normalize: Optional[bool] = None): + """ + Common joint step based on AbstractRNNTJoint implementation. + + Args: + enc: Output of the Encoder model. A torch.Tensor of shape [B, 1, H1] + pred: Output of the Decoder model. A torch.Tensor of shape [B, 1, H2] + log_normalize: Whether to log normalize or not. None will log normalize only for CPU. + + Returns: + logits of shape (B, T=1, U=1, V + 1) + """ + with torch.no_grad(): + logits = self.joint.joint(enc, pred) + + if log_normalize is None: + if not logits.is_cuda: # Use log softmax only if on CPU + logits = logits.log_softmax(dim=len(logits.shape) - 1) + else: + if log_normalize: + logits = logits.log_softmax(dim=len(logits.shape) - 1) + + return logits + + +class GreedyRNNTInfer(_GreedyRNNTInfer): + """A greedy transducer decoder. + + Sequence level greedy decoding, performed auto-repressively. + + Args: + decoder_model: rnnt_utils.AbstractRNNTDecoder implementation. + joint_model: rnnt_utils.AbstractRNNTJoint implementation. + blank_index: int index of the blank token. Can be 0 or len(vocabulary). + max_symbols_per_step: Optional int. The maximum number of symbols that can be added + to a sequence in a single time step; if set to None then there is + no limit. + """ + + def __init__( + self, + decoder_model: rnnt_abstract.AbstractRNNTDecoder, + joint_model: rnnt_abstract.AbstractRNNTJoint, + blank_index: int, + max_symbols_per_step: Optional[int] = None, + ): + super().__init__( + decoder_model=decoder_model, + joint_model=joint_model, + blank_index=blank_index, + max_symbols_per_step=max_symbols_per_step, + ) + + @typecheck() + def forward(self, encoder_output: torch.Tensor, encoded_lengths: torch.Tensor): + """Returns a list of hypotheses given an input batch of the encoder hidden embedding. + Output token is generated auto-repressively. + + Args: + encoder_output: A tensor of size (batch, features, timesteps). + encoded_lengths: list of int representing the length of each sequence + output sequence. + + Returns: + packed list containing batch number of sentences (Hypotheses). + """ + # Preserve decoder and joint training state + decoder_training_state = self.decoder.training + joint_training_state = self.joint.training + + with torch.no_grad(): + # Apply optional preprocessing + encoder_output = encoder_output.transpose(1, 2) # (B, T, D) + + self.decoder.eval() + self.joint.eval() + + hypotheses = [] + # Process each sequence independently + with self.decoder.as_frozen(), self.joint.as_frozen(): + for batch_idx in range(encoder_output.size(0)): + inseq = encoder_output[batch_idx, :, :].unsqueeze(1) # [T, 1, D] + logitlen = encoded_lengths[batch_idx] + sentence = self._greedy_decode(inseq, logitlen) + hypotheses.append(sentence) + + # Pack results into Hypotheses + packed_result = [ + rnnt_utils.Hypothesis(y_sequence=torch.tensor(sent, dtype=torch.long), score=-1.0) + for sent in hypotheses + ] + + self.decoder.train(decoder_training_state) + self.joint.train(joint_training_state) + + return (packed_result,) + + @torch.no_grad() + def _greedy_decode(self, x: torch.Tensor, out_len: torch.Tensor): + # x: [T, 1, D] + # out_len: [seq_len] + + # Initialize blank state and empty label set + hidden = None + label = [] + + # For timestep t in X_t + for time_idx in range(out_len): + # Extract encoder embedding at timestep t + # f = x[time_idx, :, :].unsqueeze(0) # [1, 1, D] + f = x.narrow(dim=0, start=time_idx, length=1) + + # Setup exit flags and counter + not_blank = True + symbols_added = 0 + + # While blank is not predicted, or we dont run out of max symbols per timestep + while not_blank and (self.max_symbols is None or symbols_added < self.max_symbols): + # In the first timestep, we initialize the network with RNNT Blank + # In later timesteps, we provide previous predicted label as input. + last_label = self._SOS if label == [] else label[-1] + + # Perform prediction network and joint network steps. + g, hidden_prime = self._pred_step(last_label, hidden) + logp = self._joint_step(f, g, log_normalize=None)[0, 0, 0, :] + + del g + + # torch.max(0) op doesnt exist for FP 16. + if logp.dtype != torch.float32: + logp = logp.float() + + # get index k, of max prob + v, k = logp.max(0) + k = k.item() # K is the label at timestep t_s in inner loop, s >= 0. + + del logp + + # If blank token is predicted, exit inner loop, move onto next timestep t + if k == self._blank_index: + not_blank = False + else: + # Append token to label set, update RNN state. + label.append(k) + hidden = hidden_prime + + # Increment token counter. + symbols_added += 1 + + return label + + +class GreedyBatchedRNNTInfer(_GreedyRNNTInfer): + """A batch level greedy transducer decoder. + + Batch level greedy decoding, performed auto-repressively. + + Args: + decoder_model: rnnt_utils.AbstractRNNTDecoder implementation. + joint_model: rnnt_utils.AbstractRNNTJoint implementation. + blank_index: int index of the blank token. Can be 0 or len(vocabulary). + max_symbols_per_step: Optional int. The maximum number of symbols that can be added + to a sequence in a single time step; if set to None then there is + no limit. + """ + + def __init__( + self, + decoder_model: rnnt_abstract.AbstractRNNTDecoder, + joint_model: rnnt_abstract.AbstractRNNTJoint, + blank_index: int, + max_symbols_per_step: Optional[int] = None, + ): + super().__init__( + decoder_model=decoder_model, + joint_model=joint_model, + blank_index=blank_index, + max_symbols_per_step=max_symbols_per_step, + ) + + # Depending on availability of `blank_as_pad` support + # switch between more efficient batch decoding technique + if self.decoder.blank_as_pad: + self._greedy_decode = self._greedy_decode_blank_as_pad + else: + self._greedy_decode = self._greedy_decode_masked + + @typecheck() + def forward(self, encoder_output: torch.Tensor, encoded_lengths: torch.Tensor): + """Returns a list of hypotheses given an input batch of the encoder hidden embedding. + Output token is generated auto-repressively. + + Args: + encoder_output: A tensor of size (batch, features, timesteps). + encoded_lengths: list of int representing the length of each sequence + output sequence. + + Returns: + packed list containing batch number of sentences (Hypotheses). + """ + # Preserve decoder and joint training state + decoder_training_state = self.decoder.training + joint_training_state = self.joint.training + + with torch.no_grad(): + # Apply optional preprocessing + encoder_output = encoder_output.transpose(1, 2) # (B, T, D) + logitlen = encoded_lengths + + self.decoder.eval() + self.joint.eval() + + with self.decoder.as_frozen(), self.joint.as_frozen(): + inseq = encoder_output # [B, T, D] + hypotheses = self._greedy_decode(inseq, logitlen, device=inseq.device) + + # Pack the hypotheses results + packed_result = [ + rnnt_utils.Hypothesis(y_sequence=torch.tensor(sent, dtype=torch.long), score=-1.0) + for sent in hypotheses + ] + + del hypotheses + + self.decoder.train(decoder_training_state) + self.joint.train(joint_training_state) + + return (packed_result,) + + def _greedy_decode_blank_as_pad(self, x: torch.Tensor, out_len: torch.Tensor, device: torch.device): + with torch.no_grad(): + # x: [B, T, D] + # out_len: [B] + # device: torch.device + + # Initialize state + hidden = None + batchsize = x.shape[0] + + # Output string buffer + label = [[] for _ in range(batchsize)] + + # Last Label buffer + Last Label without blank buffer + # batch level equivalent of the last_label + last_label = torch.full([batchsize, 1], fill_value=self._blank_index, dtype=torch.long, device=device) + + # Mask buffers + blank_mask = torch.full([batchsize], fill_value=0, dtype=torch.bool, device=device) + + # Get max sequence length + max_out_len = out_len.max() + + for time_idx in range(max_out_len): + f = x.narrow(dim=1, start=time_idx, length=1) # [B, 1, D] + + # Prepare t timestamp batch variables + not_blank = True + symbols_added = 0 + + # Reset blank mask + blank_mask.mul_(False) + + # Update blank mask with time mask + # Batch: [B, T, D], but Bi may have seq len < max(seq_lens_in_batch) + # Forcibly mask with "blank" tokens, for all sample where current time step T > seq_len + blank_mask = time_idx >= out_len + + # Start inner loop + while not_blank and (self.max_symbols is None or symbols_added < self.max_symbols): + + # Batch prediction and joint network steps + # If very first prediction step, submit SOS tag (blank) to pred_step. + # This feeds a zero tensor as input to AbstractRNNTDecoder to prime the state + if time_idx == 0 and symbols_added == 0: + g, hidden_prime = self._pred_step(self._SOS, hidden, batch_size=batchsize) + else: + # Perform batch step prediction of decoder, getting new states and scores ("g") + g, hidden_prime = self._pred_step(last_label, hidden, batch_size=batchsize) + + # Batched joint step - Output = [B, V + 1] + logp = self._joint_step(f, g, log_normalize=None)[:, 0, 0, :] + + if logp.dtype != torch.float32: + logp = logp.float() + + # Get index k, of max prob for batch + v, k = logp.max(1) + del v, g, logp + + # Update blank mask with current predicted blanks + # This is accumulating blanks over all time steps T and all target steps min(max_symbols, U) + k_is_blank = k == self._blank_index + blank_mask |= k_is_blank + + del k_is_blank + + # If all samples predict / have predicted prior blanks, exit loop early + # This is equivalent to if single sample predicted k + if blank_mask.all(): + not_blank = False + else: + # Collect batch indices where blanks occurred now/past + blank_indices = [] + if hidden is not None: + blank_indices = (blank_mask == 1).nonzero(as_tuple=False) + + # Recover prior state for all samples which predicted blank now/past + if hidden is not None: + # LSTM has 2 states + for state_id in range(len(hidden)): + hidden_prime[state_id][:, blank_indices, :] = hidden[state_id][:, blank_indices, :] + + # Recover prior predicted label for all samples which predicted blank now/past + k[blank_indices] = last_label[blank_indices, 0] + + # Update new label and hidden state for next iteration + last_label = k.clone().view(-1, 1) + hidden = hidden_prime + + # Update predicted labels, accounting for time mask + # If blank was predicted even once, now or in the past, + # Force the current predicted label to also be blank + # This ensures that blanks propogate across all timesteps + # once they have occured (normally stopping condition of sample level loop). + for kidx, ki in enumerate(k): + if blank_mask[kidx] == 0: + label[kidx].append(ki) + + symbols_added += 1 + + return label + + @torch.no_grad() + def _greedy_decode_masked(self, x: torch.Tensor, out_len: torch.Tensor, device: torch.device): + # x: [B, T, D] + # out_len: [B] + # device: torch.device + + # Initialize state + hidden = None + batchsize = x.shape[0] + + # Output string buffer + label = [[] for _ in range(batchsize)] + + # Last Label buffer + Last Label without blank buffer + # batch level equivalent of the last_label + last_label = torch.full([batchsize, 1], fill_value=self._blank_index, dtype=torch.long, device=device) + last_label_without_blank = last_label.clone() + + # Mask buffers + blank_mask = torch.full([batchsize], fill_value=0, dtype=torch.bool, device=device) + + # Get max sequence length + max_out_len = out_len.max() + for time_idx in range(max_out_len): + f = x.narrow(dim=1, start=time_idx, length=1) # [B, 1, D] + + # Prepare t timestamp batch variables + not_blank = True + symbols_added = 0 + + # Reset blank mask + blank_mask.mul_(False) + + # Update blank mask with time mask + # Batch: [B, T, D], but Bi may have seq len < max(seq_lens_in_batch) + # Forcibly mask with "blank" tokens, for all sample where current time step T > seq_len + blank_mask = time_idx >= out_len + + # Start inner loop + while not_blank and (self.max_symbols is None or symbols_added < self.max_symbols): + # Batch prediction and joint network steps + # If very first prediction step, submit SOS tag (blank) to pred_step. + # This feeds a zero tensor as input to AbstractRNNTDecoder to prime the state + if time_idx == 0 and symbols_added == 0: + g, hidden_prime = self._pred_step(self._SOS, hidden, batch_size=batchsize) + else: + # Set a dummy label for the blank value + # This value will be overwritten by "blank" again the last label update below + # This is done as vocabulary of prediction network does not contain "blank" token of RNNT + last_label_without_blank_mask = last_label == self._blank_index + last_label_without_blank[last_label_without_blank_mask] = 0 # temp change of label + last_label_without_blank[~last_label_without_blank_mask] = last_label[ + ~last_label_without_blank_mask + ] + + # Perform batch step prediction of decoder, getting new states and scores ("g") + g, hidden_prime = self._pred_step(last_label_without_blank, hidden, batch_size=batchsize) + + # Batched joint step - Output = [B, V + 1] + logp = self._joint_step(f, g, log_normalize=None)[:, 0, 0, :] + + if logp.dtype != torch.float32: + logp = logp.float() + + # Get index k, of max prob for batch + v, k = logp.max(1) + del v, g, logp + + # Update blank mask with current predicted blanks + # This is accumulating blanks over all time steps T and all target steps min(max_symbols, U) + k_is_blank = k == self._blank_index + blank_mask.bitwise_or_(k_is_blank) + + # If all samples predict / have predicted prior blanks, exit loop early + # This is equivalent to if single sample predicted k + if blank_mask.all(): + not_blank = False + else: + # Collect batch indices where blanks occurred now/past + blank_indices = [] + if hidden is not None: + blank_indices = (blank_mask == 1).nonzero(as_tuple=False) + + # Recover prior state for all samples which predicted blank now/past + if hidden is not None: + # LSTM has 2 states + for state_id in range(len(hidden)): + hidden_prime[state_id][:, blank_indices, :] = hidden[state_id][:, blank_indices, :] + + # Recover prior predicted label for all samples which predicted blank now/past + k[blank_indices] = last_label[blank_indices, 0] + + # Update new label and hidden state for next iteration + last_label = k.view(-1, 1) + hidden = hidden_prime + + # Update predicted labels, accounting for time mask + # If blank was predicted even once, now or in the past, + # Force the current predicted label to also be blank + # This ensures that blanks propogate across all timesteps + # once they have occured (normally stopping condition of sample level loop). + for kidx, ki in enumerate(k): + if blank_mask[kidx] == 0: + label[kidx].append(ki) + + symbols_added += 1 + + return label diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/rnnt_utils.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/rnnt_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..55cd53bd89ad8b1dfd89d3297880a3331655c349 --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/rnnt_utils.py @@ -0,0 +1,65 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright 2017 Johns Hopkins University (Shinji Watanabe) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Union + +import torch + + +@dataclass +class Hypothesis: + """Hypothesis class for beam search algorithms. + + score: A float score obtained from an AbstractRNNTDecoder module's score_hypothesis method. + + y_sequence: Either a sequence of integer ids pointing to some vocabulary, or a packed torch.Tensor + behaving in the same manner. dtype must be torch.Long in the latter case. + + dec_state: A list (or list of list) of LSTM-RNN decoder states. Can be None. + + y: (Unused) A list of torch.Tensors representing the list of hypotheses. + + lm_state: (Unused) A dictionary state cache used by an external Language Model. + + lm_scores: (Unused) Score of the external Language Model. + """ + + score: float + y_sequence: Union[List[int], torch.Tensor] + dec_state: Optional[Union[List[List[torch.Tensor]], List[torch.Tensor]]] = None + y: List[torch.tensor] = None + lm_state: Union[Dict[str, Any], List[Any]] = None + lm_scores: torch.Tensor = None + + +@dataclass +class NBestHypotheses: + """List of N best hypotheses""" + + n_best_hypotheses: Optional[List[Hypothesis]] diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/segment.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/segment.py new file mode 100644 index 0000000000000000000000000000000000000000..85b4c38216ea5665982faf9c963a37e0c1fe7c57 --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/segment.py @@ -0,0 +1,223 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Copyright (c) 2018 Ryan Leary +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# This file contains code artifacts adapted from https://github.com/ryanleary/patter + +import random + +import librosa +import numpy as np +import soundfile as sf + + +class AudioSegment(object): + """Monaural audio segment abstraction. + :param samples: Audio samples [num_samples x num_channels]. + :type samples: ndarray.float32 + :param sample_rate: Audio sample rate. + :type sample_rate: int + :raises TypeError: If the sample data type is not float or int. + """ + + def __init__(self, samples, sample_rate, target_sr=None, trim=False, trim_db=60, orig_sr=None): + """Create audio segment from samples. + Samples are convert float32 internally, with int scaled to [-1, 1]. + """ + samples = self._convert_samples_to_float32(samples) + if target_sr is not None and target_sr != sample_rate: + samples = librosa.core.resample(samples, sample_rate, target_sr) + sample_rate = target_sr + if trim: + samples, _ = librosa.effects.trim(samples, trim_db) + self._samples = samples + self._sample_rate = sample_rate + if self._samples.ndim >= 2: + self._samples = np.mean(self._samples, 1) + + self._orig_sr = orig_sr if orig_sr is not None else sample_rate + + def __eq__(self, other): + """Return whether two objects are equal.""" + if type(other) is not type(self): + return False + if self._sample_rate != other._sample_rate: + return False + if self._samples.shape != other._samples.shape: + return False + if np.any(self.samples != other._samples): + return False + return True + + def __ne__(self, other): + """Return whether two objects are unequal.""" + return not self.__eq__(other) + + def __str__(self): + """Return human-readable representation of segment.""" + return "%s: num_samples=%d, sample_rate=%d, duration=%.2fsec, rms=%.2fdB" % ( + type(self), + self.num_samples, + self.sample_rate, + self.duration, + self.rms_db, + ) + + @staticmethod + def _convert_samples_to_float32(samples): + """Convert sample type to float32. + Audio sample type is usually integer or float-point. + Integers will be scaled to [-1, 1] in float32. + """ + float32_samples = samples.astype('float32') + if samples.dtype in np.sctypes['int']: + bits = np.iinfo(samples.dtype).bits + float32_samples *= 1.0 / 2 ** (bits - 1) + elif samples.dtype in np.sctypes['float']: + pass + else: + raise TypeError("Unsupported sample type: %s." % samples.dtype) + return float32_samples + + @classmethod + def from_file( + cls, audio_file, target_sr=None, int_values=False, offset=0, duration=0, trim=False, orig_sr=None, + ): + """ + Load a file supported by librosa and return as an AudioSegment. + :param audio_file: path of file to load + :param target_sr: the desired sample rate + :param int_values: if true, load samples as 32-bit integers + :param offset: offset in seconds when loading audio + :param duration: duration in seconds when loading audio + :return: numpy array of samples + """ + with sf.SoundFile(audio_file, 'r') as f: + dtype = 'int32' if int_values else 'float32' + sample_rate = f.samplerate + if offset > 0: + f.seek(int(offset * sample_rate)) + if duration > 0: + samples = f.read(int(duration * sample_rate), dtype=dtype) + else: + samples = f.read(dtype=dtype) + + samples = samples.transpose() + return cls(samples, sample_rate, target_sr=target_sr, trim=trim, orig_sr=orig_sr) + + @classmethod + def segment_from_file(cls, audio_file, target_sr=None, n_segments=0, trim=False, orig_sr=None): + """Grabs n_segments number of samples from audio_file randomly from the + file as opposed to at a specified offset. + + Note that audio_file can be either the file path, or a file-like object. + """ + with sf.SoundFile(audio_file, 'r') as f: + sample_rate = f.samplerate + if n_segments > 0 and len(f) > n_segments: + max_audio_start = len(f) - n_segments + audio_start = random.randint(0, max_audio_start) + f.seek(audio_start) + samples = f.read(n_segments, dtype='float32') + else: + samples = f.read(dtype='float32') + + samples = samples.transpose() + return cls(samples, sample_rate, target_sr=target_sr, trim=trim, orig_sr=orig_sr) + + @property + def samples(self): + return self._samples.copy() + + @property + def sample_rate(self): + return self._sample_rate + + @property + def num_samples(self): + return self._samples.shape[0] + + @property + def duration(self): + return self._samples.shape[0] / float(self._sample_rate) + + @property + def rms_db(self): + mean_square = np.mean(self._samples ** 2) + return 10 * np.log10(mean_square) + + @property + def orig_sr(self): + return self._orig_sr + + def gain_db(self, gain): + self._samples *= 10.0 ** (gain / 20.0) + + def pad(self, pad_size, symmetric=False): + """Add zero padding to the sample. The pad size is given in number + of samples. + If symmetric=True, `pad_size` will be added to both sides. If false, + `pad_size` + zeros will be added only to the end. + """ + self._samples = np.pad(self._samples, (pad_size if symmetric else 0, pad_size), mode='constant',) + + def subsegment(self, start_time=None, end_time=None): + """Cut the AudioSegment between given boundaries. + Note that this is an in-place transformation. + :param start_time: Beginning of subsegment in seconds. + :type start_time: float + :param end_time: End of subsegment in seconds. + :type end_time: float + :raise ValueError: If start_time or end_time is incorrectly set, + e.g. out + of bounds in time. + """ + start_time = 0.0 if start_time is None else start_time + end_time = self.duration if end_time is None else end_time + if start_time < 0.0: + start_time = self.duration + start_time + if end_time < 0.0: + end_time = self.duration + end_time + if start_time < 0.0: + raise ValueError("The slice start position (%f s) is out of bounds." % start_time) + if end_time < 0.0: + raise ValueError("The slice end position (%f s) is out of bounds." % end_time) + if start_time > end_time: + raise ValueError( + "The slice start position (%f s) is later than the end position (%f s)." % (start_time, end_time) + ) + if end_time > self.duration: + raise ValueError("The slice end position (%f s) is out of bounds (> %f s)" % (end_time, self.duration)) + start_sample = int(round(start_time * self._sample_rate)) + end_sample = int(round(end_time * self._sample_rate)) + self._samples = self._samples[start_sample:end_sample] diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/simple_wer_v2.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/simple_wer_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..172d2ef71ac0759a041595cfdc1d3224ae3d177e --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/simple_wer_v2.py @@ -0,0 +1,454 @@ +# Lint as: python2, python3 +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""The new version script to evalute the word error rate (WER) for ASR tasks. + +Tensorflow and Lingvo are not required to run this script. + +Example of Usage: + +a) `python simple_wer_v2.py file_hypothesis file_reference` +b) `python simple_wer_v2.py file_hypothesis file_reference file_keyphrases` + +where `file_hypothesis` is the filename for hypothesis text, +`file_reference` is the filename for reference text, and +`file_keyphrases` is the optional filename for important phrases +(one phrase per line). + +Note that the program will also generate a html to diagnose the errors, +and the html filename is `{$file_hypothesis}_diagnois.html`. + +Another way is to use this file as a stand-alone library, by calling class +SimpleWER with the following member functions: + +- AddHypRef(hyp, ref): Updates the evaluation for each (hyp,ref) pair. +- GetWER(): Computes word error rate (WER) for all the added hyp-ref pairs. +- GetSummaries(): Generates strings to summarize word and key phrase errors. +- GetKeyPhraseStats(): Measures stats for key phrases. + Stats include: + (1) Jaccard similarity: https://en.wikipedia.org/wiki/Jaccard_index. + (2) F1 score: https://en.wikipedia.org/wiki/Precision_and_recall. + +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import re +import sys +from six.moves import range + + +def TxtPreprocess(txt): + """Preprocess text before WER caculation.""" + + # Lowercase, remove \t and new line. + txt = re.sub(r'[\t\n]', ' ', txt.lower()) + + # Remove punctuation before space. + txt = re.sub(r'[,.\?!]+ ', ' ', txt) + + # Remove punctuation before end. + txt = re.sub(r'[,.\?!]+$', ' ', txt) + + # Remove punctuation after space. + txt = re.sub(r' [,.\?!]+', ' ', txt) + + # Remove quotes, [, ], ( and ). + txt = re.sub(r'["\(\)\[\]]', '', txt) + + # Remove extra space. + txt = re.sub(' +', ' ', txt.strip()) + + return txt + + +def RemoveCommentTxtPreprocess(txt): + """Preprocess text and remove comments in the brancket, such as [comments].""" + + # Remove comments surrounded by box brackets: + txt = re.sub(r'\[\w+\]', '', txt) + + return TxtPreprocess(txt) + + +def HighlightAlignedHtml(hyp, ref, err_type): + """Generate a html element to highlight the difference between hyp and ref. + + Args: + hyp: Hypothesis string. + ref: Reference string. + err_type: one of 'none', 'sub', 'del', 'ins'. + + Returns: + a html string where disagreements are highlighted. + Note `hyp` is highlighted in green, and marked with + `ref` is highlighted in yellow. If you want html with nother styles, + consider to write your own function. + + Raises: + ValueError: if err_type is not among ['none', 'sub', 'del', 'ins']. + or if when err_type == 'none', hyp != ref + """ + + highlighted_html = '' + if err_type == 'none': + if hyp != ref: + raise ValueError('hyp (%s) does not match ref (%s) for none error' % + (hyp, ref)) + highlighted_html += '%s ' % hyp + + elif err_type == 'sub': + highlighted_html += """ + %s + %s """ % (hyp, ref) + + elif err_type == 'del': + highlighted_html += """ + %s """ % ( + ref) + + elif err_type == 'ins': + highlighted_html += """ + %s """ % ( + hyp) + + else: + raise ValueError('unknown err_type ' + err_type) + + return highlighted_html + + +def ComputeEditDistanceMatrix(hyp_words, ref_words): + """Compute edit distance between two list of strings. + + Args: + hyp_words: the list of words in the hypothesis sentence + ref_words: the list of words in the reference sentence + + Returns: + Edit distance matrix (in the format of list of lists), where the first + index is the reference and the second index is the hypothesis. + """ + reference_length_plus = len(ref_words) + 1 + hypothesis_length_plus = len(hyp_words) + 1 + edit_dist_mat = [[]] * reference_length_plus + + # Initialization. + for i in range(reference_length_plus): + edit_dist_mat[i] = [0] * hypothesis_length_plus + for j in range(hypothesis_length_plus): + if i == 0: + edit_dist_mat[0][j] = j + elif j == 0: + edit_dist_mat[i][0] = i + + # Do dynamic programming. + for i in range(1, reference_length_plus): + for j in range(1, hypothesis_length_plus): + if ref_words[i - 1] == hyp_words[j - 1]: + edit_dist_mat[i][j] = edit_dist_mat[i - 1][j - 1] + else: + tmp0 = edit_dist_mat[i - 1][j - 1] + 1 + tmp1 = edit_dist_mat[i][j - 1] + 1 + tmp2 = edit_dist_mat[i - 1][j] + 1 + edit_dist_mat[i][j] = min(tmp0, tmp1, tmp2) + + return edit_dist_mat + + +class SimpleWER(object): + """Compute word error rates after the alignment. + + Attributes: + key_phrases: list of important phrases. + aligned_htmls: list of diagnois htmls, each of which corresponding to a pair + of hypothesis and reference. + hyp_keyphrase_counts: dict. `hyp_keyphrase_counts[w]` counts how often a key + phrases `w` appear in the hypotheses. + ref_keyphrase_counts: dict. `ref_keyphrase_counts[w]` counts how often a key + phrases `w` appear in the references. + matched_keyphrase_counts: dict. `matched_keyphrase_counts[w]` counts how + often a key phrase `w` appear in the aligned transcripts when the + reference and hyp_keyphrase match. + wer_info: dict with four keys: 'sub' (substitution error), 'ins' (insersion + error), 'del' (deletion error), 'nw' (number of words). We can use + wer_info to compute word error rate (WER) as + (wer_info['sub']+wer_info['ins']+wer_info['del'])*100.0/wer_info['nw'] + """ + + def __init__(self, + key_phrases=None, + html_handler=HighlightAlignedHtml, + preprocess_handler=RemoveCommentTxtPreprocess): + """Initialize SimpleWER object. + + Args: + key_phrases: list of strings as important phrases. If key_phrases is + None, no key_phrases related metric will be computed. + html_handler: function to generate a string with html tags. + preprocess_handler: function to my_preprocess text before computing WER. + """ + self._preprocess_handler = preprocess_handler + self._html_handler = html_handler + self.key_phrases = key_phrases + self.aligned_htmls = [] + self.wer_info = {'sub': 0, 'ins': 0, 'del': 0, 'nw': 0} + if key_phrases: + # Pre-process key_phrase list + if self._preprocess_handler: + self.key_phrases = \ + [self._preprocess_handler(k) for k in self.key_phrases] + + # Init keyphrase_counts for every key phrase + self.ref_keyphrase_counts = {} + self.hyp_keyphrase_counts = {} + self.matched_keyphrase_counts = {} + for k in self.key_phrases: + self.ref_keyphrase_counts[k] = 0 + self.hyp_keyphrase_counts[k] = 0 + self.matched_keyphrase_counts[k] = 0 + else: + self.ref_keyphrase_counts = None + self.hyp_keyphrase_counts = None + self.matched_keyphrase_counts = None + + def AddHypRef(self, hypothesis, reference): + """Update WER when adding one pair of strings: (hypothesis, reference). + + Args: + hypothesis: Hypothesis string. + reference: Reference string. + + Raises: + ValueError: when the program fails to parse edit distance matrix. + """ + if self._preprocess_handler: + hypothesis = self._preprocess_handler(hypothesis) + reference = self._preprocess_handler(reference) + + # Compute edit distance. + hyp_words = hypothesis.split() + ref_words = reference.split() + distmat = ComputeEditDistanceMatrix(hyp_words, ref_words) + + # Back trace, to distinguish different erroref_words: ins, del, sub. + pos_hyp, pos_ref = len(hyp_words), len(ref_words) + wer_info = {'sub': 0, 'ins': 0, 'del': 0, 'nw': len(ref_words)} + aligned_html = '' + matched_ref = '' + while pos_hyp > 0 or pos_ref > 0: + err_type = '' + + # Distinguish error type by back tracking + if pos_ref == 0: + err_type = 'ins' + elif pos_hyp == 0: + err_type = 'del' + else: + if hyp_words[pos_hyp - 1] == ref_words[pos_ref - 1]: + err_type = 'none' # correct error + elif distmat[pos_ref][pos_hyp] == distmat[pos_ref - 1][pos_hyp - 1] + 1: + err_type = 'sub' # substitute error + elif distmat[pos_ref][pos_hyp] == distmat[pos_ref - 1][pos_hyp] + 1: + err_type = 'del' # deletion error + elif distmat[pos_ref][pos_hyp] == distmat[pos_ref][pos_hyp - 1] + 1: + err_type = 'ins' # insersion error + else: + raise ValueError('fail to parse edit distance matrix.') + + # Generate aligned_html + if self._html_handler: + if pos_hyp == 0 or not hyp_words: + tmph = ' ' + else: + tmph = hyp_words[pos_hyp - 1] + if pos_ref == 0 or not ref_words: + tmpr = ' ' + else: + tmpr = ref_words[pos_ref - 1] + aligned_html = self._html_handler(tmph, tmpr, err_type) + aligned_html + + # If no error, go to previous ref and hyp. + if err_type == 'none': + matched_ref = hyp_words[pos_hyp - 1] + ' ' + matched_ref + pos_hyp, pos_ref = pos_hyp - 1, pos_ref - 1 + continue + + # Update error. + wer_info[err_type] += 1 + + # Adjust position of ref and hyp. + if err_type == 'del': + pos_ref = pos_ref - 1 + elif err_type == 'ins': + pos_hyp = pos_hyp - 1 + else: # err_type == 'sub' + pos_hyp, pos_ref = pos_hyp - 1, pos_ref - 1 + + # Verify the computation of edit distance finishes + assert distmat[-1][-1] == wer_info['ins'] + \ + wer_info['del'] + wer_info['sub'] + + # Accumulate err_info before the next (hyp, ref). + for k in wer_info: + self.wer_info[k] += wer_info[k] + + # Collect aligned_htmls. + if self._html_handler: + self.aligned_htmls += [aligned_html] + + # Update key phrase info. + if self.key_phrases: + for w in self.key_phrases: + self.ref_keyphrase_counts[w] += reference.count(w) + self.hyp_keyphrase_counts[w] += hypothesis.count(w) + self.matched_keyphrase_counts[w] += matched_ref.count(w) + + def GetWER(self): + """Compute Word Error Rate (WER) to summarize word erroref_words. + + Note WER can be larger than 100.0, esp when there are many insertion errors. + + Returns: + WER as percentage number, usually between 0.0 to 100.0 + """ + nref = self.wer_info['nw'] + nref = max(1, nref) # non_zero value for division + total_error = self.wer_info['ins'] \ + + self.wer_info['del'] + self.wer_info['sub'] + return total_error * 100.0 / nref + + def GetKeyPhraseStats(self): + """Measure the Jaccard similarity of key phrases between hyps and refs. + + Returns: + jaccard_similarity: jaccard similarity, between 0.0 and 1.0 + F1_keyphrase: F1 score (=2/(1/prec + 1/recall)), between 0.0 and 1.0 + matched_keyphrases: num of matched key phrases. + ref_keyphrases: num of key phrases in the reference strings. + hyp_keyphrases: num of key phrases in the hypothesis strings. + """ + + matched_k = sum(self.matched_keyphrase_counts.values()) + ref_k = sum(self.ref_keyphrase_counts.values()) + hyp_k = sum(self.hyp_keyphrase_counts.values()) + joined_k = ref_k + hyp_k - matched_k + joined_k = max(1, joined_k) # non_zero value for division + jaccard_similarity = matched_k * 1.0 / joined_k + + f1_k = 2.0 * matched_k / max(ref_k + hyp_k, 1.0) + return (jaccard_similarity, f1_k, matched_k, ref_k, hyp_k) + + def GetSummaries(self): + """Generate strings to summarize word errors and key phrase errors. + + Returns: + str_sum: string summarizing total error, total word and WER. + str_details: string breaking down three error types: del, ins, sub. + str_str_keyphrases_info: string summarizing kerphrase information. + """ + wer = self.GetWER() + nref = self.wer_info['nw'] + total_error = self.wer_info['ins'] \ + + self.wer_info['del'] + self.wer_info['sub'] + str_sum = 'WER = %.2f%% (%.4f%%), total error = %d, total word = %d' % ( + wer, wer, total_error, nref) + + str_details = 'Error breakdown: del = %.2f%%, ins=%.2f%%, sub=%.2f%%' % ( + self.wer_info['del'] * 100.0 / nref, self.wer_info['ins'] * 100.0 / + nref, self.wer_info['sub'] * 100.0 / nref) + + str_keyphrases_info = '' + if self.key_phrases: + jaccard_p, f1_p, matched_p, ref_p, hyp_p = self.GetKeyPhraseStats() + str_keyphrases_info = ('matched %d key phrases (%d in ref, %d in hyp), ' + 'jaccard similarity=%.2f, F1=%.2f') % \ + (matched_p, ref_p, hyp_p, jaccard_p, f1_p) + + return str_sum, str_details, str_keyphrases_info, (wer, total_error, nref) + + def write_html(self, fn_output): + try: + aligned_html = '\n'.join('

  • {}
  • '.format(html) for html in self.aligned_htmls) + aligned_html = '
      {}
    '.format(aligned_html) + with open(fn_output, 'wt') as fp: + fp.write('') + fp.write('') + fp.write('
    %s
    ' % aligned_html) + fp.write('') + except IOError: + print('failed to write diagnosis html') + + +def main(argv): + hypothesis = open(argv[1], 'r').read() + reference = open(argv[2], 'r').read() + + if len(argv) == 4: + phrase_lines = open(argv[3]).readlines() + keyphrases = [line.strip() for line in phrase_lines] + else: + keyphrases = None + + wer_obj = SimpleWER( + key_phrases=keyphrases, + html_handler=HighlightAlignedHtml, + preprocess_handler=RemoveCommentTxtPreprocess) + + wer_obj.AddHypRef(hypothesis, reference) + + str_summary, str_details, str_keyphrases_info = wer_obj.GetSummaries() + print(str_summary) + print(str_details) + print(str_keyphrases_info) + + try: + fn_output = argv[1] + '_diagnosis.html' + aligned_html = '
    '.join(wer_obj.aligned_htmls) + with open(fn_output, 'wt') as fp: + fp.write('') + fp.write('
    %s
    ' % aligned_html) + fp.write('') + except IOError: + print('failed to write diagnosis html') + + +if __name__ == '__main__': + if len(sys.argv) < 3 or len(sys.argv) > 4: + print(""" +Example of Usage: + + python simple_wer_v2.py file_hypothesis file_reference +or + python simple_wer_v2.py file_hypothesis file_reference file_keyphrases + + where file_hypothesis is the file name for hypothesis text + file_reference is the file name for reference text. + file_keyphrases (optional) is the filename of key phrases over which + you want to measure accuracy. + +Or you can use this file as a library, and call class SimpleWER + .AddHypRef(hyp, ref): add one pair of hypothesis/reference. You can call this + function multiple times. + .GetWER(): get the Word Error Rate (WER). + .GetKeyPhraseStats(): get stats for key phrases. The first value is Jaccard + Similarity of key phrases. + .GetSummaries(): generate strings to summarize word error and + key phrase errors. +""") + sys.exit(1) + + main(sys.argv) diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/spec2vec.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/spec2vec.py new file mode 100644 index 0000000000000000000000000000000000000000..272580476d9b544e70e7d6c708b5edb35dc27887 --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/spec2vec.py @@ -0,0 +1,210 @@ +from typing import List + +import torch +from torch import nn + +from nemo.collections.asr.models.configs import common_config as common_cfg +from nemo.collections.asr.models.spec2vec.spec2vec_config import ConvTransformerBlock +from nemo.collections.asr.parts.convolution_layers import ConvNormAct, create_pad_mask, Conv +from nemo.collections.asr.parts.wav2vec import TransformerEncoder + + +class FeatureEncoder(nn.Module): + def __init__(self, feat_in, use_conv_mask, conv2d_block: common_cfg.Conv2dBlock, + conv_transformer_blocks: List[ConvTransformerBlock], + use_tf_pad: bool, ln_eps: float = 1e-5): + super().__init__() + + self.use_conv_mask = use_conv_mask + + self.bn_moudles = [] + + if conv2d_block: + prev_out_channels = 1 + self.conv2d_block = nn.ModuleList() + for conv2d_cfg_i in conv2d_block.layers: + layer = ConvNormAct(in_channels=prev_out_channels, + conv_type='2d', + use_tf_pad=use_tf_pad, + ln_eps=ln_eps, + **conv2d_cfg_i) + + if isinstance(layer.norm, (nn.BatchNorm2d, nn.BatchNorm1d)): + self.bn_moudles.append(layer.norm) + + prev_out_channels = conv2d_cfg_i.filters + self.conv2d_block.append(layer) + prev_out_channels = conv2d_block.output_dim + self.conv2d_block.apply(kaiming_init_conv_weights) + else: + self.conv2d_block = None + prev_out_channels = feat_in + + self.block_modules = nn.ModuleList() + for block_cfg in conv_transformer_blocks: + for conv_cfg_i in block_cfg.conv_layers: + layer = ConvNormAct(in_channels=prev_out_channels, + conv_type='1d', + use_tf_pad=use_tf_pad, + ln_eps=ln_eps, + **conv_cfg_i) + + if isinstance(layer.norm, (nn.BatchNorm2d, nn.BatchNorm1d)): + self.bn_moudles.append(layer.norm) + + prev_out_channels = conv_cfg_i.filters + layer.apply(kaiming_init_conv_weights) + self.block_modules.append(layer) + + if block_cfg.transformer_block is not None: + block = TransformerEncoder(block_cfg.transformer_block) + self.block_modules.append(block) + prev_out_channels = block_cfg.transformer_block.encoder.embedding_dim + + self.output_dim = prev_out_channels + + def forward(self, audio_signal, length): + # [B, F/D, T] + output = audio_signal + + if self.use_conv_mask: + pad_mask = create_pad_mask(length, max_len=output.size(2)) + else: + pad_mask = None + + if self.conv2d_block is not None: + # [B, F, T] => [B, T, F] =>[B, C, T, F] + output = torch.transpose(output, 1, 2).unsqueeze(1) + for module in self.conv2d_block: + output, length, pad_mask = module(output, length, pad_mask=pad_mask) + b, c, t, f = output.size() + # [B, C, T, F] => [B, F, C, T] => [B, FxC/D, T] + output = output.permute(0, 3, 1, 2).reshape(b, f * c, t) + + for module in self.block_modules: + if isinstance(module, ConvNormAct): + output, length, pad_mask = module(output, length, pad_mask=pad_mask) + else: + assert isinstance(module, TransformerEncoder) + # [B, D, T] => [B, T, D] + output = output.transpose(1, 2) + output = module(output, padding_mask=pad_mask) + # [B, T, D] => [B, D, T] + output = output.transpose(1, 2) + + return output, length, None + + def bn_eval(self): + for m in self.bn_moudles: + m.eval() + + def get_subsampled_lens(self, lens): + if self.conv2d_block is not None: + for module in self.conv2d_block: + lens = module.update_out_seq_lens(lens) + + for module in self.block_modules: + if isinstance(module, ConvNormAct): + lens = module.update_out_seq_lens(lens) + + return lens + + +class Projector(nn.Module): + def __init__(self, cfg): + super().__init__() + + self.use_conv_mask = cfg.use_conv_mask + + prev_out_channels = cfg.input_dim + + if cfg.conv_layers is not None: + self.conv_layers = nn.ModuleList() + for conv_cfg_i in cfg.conv_layers: + assert conv_cfg_i.stride == (1,) + layer = ConvNormAct(in_channels=prev_out_channels, + conv_type='1d', + use_tf_pad=cfg.use_tf_pad, + ln_eps=cfg.ln_eps, + **conv_cfg_i) + prev_out_channels = conv_cfg_i.filters + self.conv_layers.append(layer) + self.conv_layers.apply(kaiming_init_conv_weights) + else: + self.conv_layers = None + + self.transformer = None if cfg.transformer is None else TransformerEncoder(cfg.transformer) + + if cfg.output_dim is not None: + self.output_proj = nn.Linear(prev_out_channels, cfg.output_dim) + self.output_dim = cfg.output_dim + else: + self.output_proj = None + self.output_dim = prev_out_channels + + def forward(self, inputs, length): + # [B, T, D] + assert inputs.shape[0] == length.shape[0] + output = inputs + + if (self.conv_layers is not None and self.use_conv_mask) or self.transformer is not None: + pad_mask = create_pad_mask(length, max_len=output.size(1)) + else: + pad_mask = None + + if self.conv_layers is not None: + # [B, T, D] => [B, D, T] + output = output.transpose(1, 2) + for conv_i in self.conv_layers: + output, length, pad_mask = conv_i(output, length, pad_mask=pad_mask) + # [B, D, T] => [B, T, D] + output = output.transpose(1, 2) + + if self.transformer is not None: + assert pad_mask is not None + output = self.transformer(output, padding_mask=pad_mask) + + if self.output_proj is not None: + output = self.output_proj(output) + + return output + + +def kaiming_init_conv_weights(m): + if isinstance(m, (nn.Conv1d, nn.Conv2d)): + nn.init.kaiming_normal_(m.weight) + elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.LayerNorm, nn.GroupNorm)): + pass # use default init + elif isinstance(m, (nn.Dropout, nn.ReLU, nn.Sequential, nn.ModuleList)): + pass # ignore modules do not need init + elif isinstance(m, (FeatureEncoder, ConvNormAct, Conv)): + pass # ignore wrapper modules + else: + raise ValueError('initializing unknown module type {}'.format(type(m))) + + +class RandomMask(nn.Module): + def __init__(self, prob, mask_value=None, mask_dim=None): + super().__init__() + assert 0 <= prob < 1 + self.prob = prob + if mask_value is not None: + assert mask_dim is None + self.mask_value = mask_value + self.embedding_mask = False + else: + assert isinstance(mask_dim, int) + self.mask_value = nn.Parameter(torch.FloatTensor(mask_dim).uniform_()) + self.embedding_mask = True + + def forward(self, inputs: torch.Tensor): + if not self.training: + return inputs + + if self.embedding_mask: + mask_shape = inputs.size()[:-1] + else: + mask_shape = inputs.size() + mask_indices = torch.bernoulli(torch.full(mask_shape, self.prob, device=inputs.device)).type(torch.bool) + inputs[mask_indices] = self.mask_value + return inputs diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/spectr_augment.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/spectr_augment.py new file mode 100644 index 0000000000000000000000000000000000000000..4e2b1cc0cea2291af4a1a0f5761d131a68ae5580 --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/spectr_augment.py @@ -0,0 +1,144 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random + +import torch +import torch.nn as nn + + +GAUSSIAN_MASK = [-0.008577285334467888, -0.008312131278216839, -0.003433679463341832, -0.007991528138518333, 0.022046754136681557, -0.014587175101041794, 0.010684879496693611, 0.0032962223049253225, 0.0006217826739884913, -0.0109024653211236, 0.003951661288738251, 0.006253963336348534, 0.005004990380257368, -0.004416681360453367, -0.014572846703231335, -0.006137363612651825, 0.015708647668361664, 0.013687868602573872, -0.0034236200153827667, -0.017063571140170097, -0.005857351701706648, 0.00830098520964384, 0.001996119972318411, 0.0001801295584300533, 0.007846095599234104, 0.006822641938924789, 0.0053718010894954205, -0.006921472027897835, -0.008087077178061008, 0.0019816632848232985, 0.017332371324300766, 0.0013035658048465848, -0.005086318589746952, 0.0038084506522864103, -0.0028235530480742455, 0.004277005326002836, -0.008790107443928719, -0.01645108126103878, -0.00870309118181467, 0.0030529664363712072, 0.0044243973679840565, -0.001729264622554183, 0.0018796413205564022, 0.001113063539378345, 0.001685396651737392, 0.007856708019971848, -0.008503658697009087, 0.004928226582705975, 0.003501072758808732, 0.013370650820434093, 0.015068281441926956, -0.008089980110526085, -0.00913938321173191, 0.010570412501692772, -0.0008485731086693704, -0.0017308632377535105, 0.009554422460496426, 0.0008375696488656104, 0.01573362946510315, -0.00033936771797016263, -0.00738796591758728, 0.004305792506784201, 0.01625652238726616, 0.004601571708917618, 0.0033054756931960583, -0.006255044601857662, -0.004542464856058359, 0.013747476041316986, -0.010799456387758255, 0.0024993526749312878, 0.005865572020411491, -0.013002256862819195, 0.005194664001464844, -0.021207045763731003, 0.007146795745939016, 0.001368773402646184, -0.014667760580778122, 0.010809154249727726, -0.005281668156385422, 0.009397738613188267, -0.0117409098893404, 0.010593295097351074, 0.00379750388674438, 0.005227712448686361, 0.005195990204811096, 0.007941234856843948, -0.008073929697275162, 0.0005659427843056619, -0.002949729096144438, 0.004707518499344587, 0.00013310337089933455, -0.0021645540837198496, 0.005021876189857721, 0.006173995789140463, 0.007487660273909569, 0.00892797950655222, 0.008880391716957092, 0.017241517081856728, 0.007244352716952562, -0.0020703296177089214, -0.003922614734619856, -0.017371725291013718, 0.0036656244192272425, 0.0008915121434256434, -0.010720319114625454, -0.0029386195819824934, -0.004162450321018696, -0.003555792849510908, -0.009222142398357391, 0.008136272430419922, -0.016355125233530998, -0.0019096146570518613, 0.009561257436871529, -0.0009358802926726639, -0.009368732571601868, -0.0006803714786656201, -0.0032679459545761347, -0.010729965753853321, -0.006920733489096165, -0.007359025534242392, -0.0036679436452686787, -0.01085688453167677, 0.0032564126886427402, -0.00396075751632452, 0.003947695717215538, 0.006906456314027309, -0.00874226726591587, -0.010531643405556679, 0.002544721122831106, 0.003992108628153801, 0.002931957133114338, -0.008593415841460228, -0.00834380928426981, -0.0004008698451798409, -0.007474125362932682, -0.0018838047981262207, -0.010623954236507416, 0.0004698234552051872, -0.0064691524021327496, -0.001641935552470386, 0.030333172529935837, 0.01937473751604557, 0.008528976701200008, -0.015279320999979973, -0.009957612492144108, 0.004345834255218506, 0.0052221533842384815, 0.028346601873636246, 0.00015258256462402642, 0.017097460106015205, -0.011207249946892262, -0.0013319909339770675, -0.01574462465941906, -0.001186217530630529, 0.002741827629506588, -0.0027266093529760838, -0.0019425811478868127, -0.016291635110974312, 0.0014635918196290731, -0.0068401917815208435, -0.0023575483355671167, -0.023059917613863945, -0.004130368120968342, 0.018156377598643303, -0.010979206301271915, 0.0030661290511488914, 0.001635069027543068, -0.013064907863736153, -0.0013261985732242465, 0.008750525303184986, 0.01025752630084753, 0.006366524379700422, -0.019202180206775665, 0.012242366559803486, 0.005239878781139851, 0.006112856324762106, -0.0009505008929409087, 0.007498623337596655, 0.004345917142927647, 0.0019737775437533855, 0.01724175363779068, -0.011013401672244072, 0.01421738788485527, -0.00010386518988525495, 0.0005218135192990303, -7.810740498825908e-05, -0.0013396443100646138, 0.00891269650310278, 0.006826599128544331, -0.005983854178339243, -0.008014648221433163, -0.004054467659443617, 0.003911960404366255, 0.012865566648542881, 0.006723911035805941, 0.00459368946030736, 0.008541394025087357, 0.014683294110000134, 0.017786560580134392, -0.006093125324696302, 0.0011138939298689365, 0.005661547649651766, -0.001207416644319892, -0.008716529235243797, 0.004718406591564417, -0.004078548867255449, 0.004175822716206312, -0.007703948300331831, 0.01632404327392578, -0.00023797148605808616, 0.0023386836983263493, 0.0049704634584486485, 0.01812780275940895, 0.008168922737240791, -0.01889769174158573, -0.007803930900990963, 0.004977164324373007, 0.005988257005810738, -0.0024359619710594416, 0.015691304579377174, -0.004170612432062626, 0.011284936219453812, -0.012687956914305687, -0.013468815013766289, -0.0056364708580076694, 0.003777650184929371, -0.010477692820131779, -3.670304431580007e-05, 0.002994526643306017, -0.02478531375527382, -0.0186957735568285, -0.000980377197265625, -0.006090754177421331, -0.0008337362669408321, 0.0042216661386191845, -0.0017890859162434936, -0.017191831022500992, 0.007558343932032585, -0.0004839670436922461, 0.0006820259150117636, -0.0036172219552099705, 0.012755093164741993, 0.007131294813007116, 0.01682445965707302, -0.0033714391756802797, 0.011848727241158485, -0.008541670627892017, 0.005448013544082642, 0.005193810444325209, 0.0003855297400150448, 0.009456363506615162, -0.004641769919544458, 0.025975538417696953, -0.008149398490786552, -0.014135519042611122, 0.006781088653951883, 0.004117966629564762, -0.006494637578725815, -0.007128587923943996, 0.006647801958024502, -0.012087219394743443, 0.001748546026647091, 0.01748277060687542, -0.004554575774818659, 0.0036450892221182585, -0.010104892775416374, -0.0035864445380866528, 0.0058875922113657, -0.0011381434742361307, 0.0015288969734683633, 0.0013207353185862303, -0.003274752525612712, 0.0023052070755511522, -0.011285059154033661, 0.011404736898839474, -0.001985494513064623, -0.008716653101146221, 0.005484268069267273, -0.0021095615811645985, 0.015584509819746017, -0.008399765007197857, 0.007663742173463106, -0.0046447026543319225, 0.01980607770383358, -0.021430866792798042, 0.0019235932268202305, 0.01850181445479393, 0.014331997372210026, 0.004704942926764488, -0.004087083041667938, -0.0052873240783810616, -0.004040342289954424, -0.0009587344247847795, -0.001312632579356432, 0.006471287924796343, -0.004858402069658041, 0.006215596571564674, -0.01835842989385128, -0.013215125538408756, -0.012908847071230412, -0.005191713571548462, 0.005219723097980022, -0.0015103589976206422, -0.006740736309438944, -0.005966144613921642, -0.0008368652779608965, 0.003591155633330345, -0.0025086551904678345, 0.018543150275945663, 0.01635683700442314, 0.003782284678891301, -0.012239447794854641, 0.0011241791071370244, -8.145845640683547e-05, 0.01495729386806488, 0.0033017871901392937, 0.00019223176059313118, 0.007682753726840019, 0.011071891523897648, 0.011146822944283485, -0.006912447512149811, 0.0034845354966819286, 0.005456835497170687, 0.002649294678121805, -0.005829342640936375, -0.002240463625639677, 0.0029810150153934956, 0.0026485121343284845, 0.012740016914904118, 0.006942421197891235, 0.008712748065590858, -0.002921474166214466, -0.003271800698712468, -0.0006888209609314799, 0.000602235842961818, -0.009740591049194336, -0.005469066556543112, 0.0027335130143910646, -0.007446215022355318, -0.010410390794277191, -0.020527798682451248, -0.0007794187986291945, -0.010217029601335526, -0.008854899555444717, 0.0018076488049700856, 0.003745997091755271, -0.008995960466563702, 0.006499497219920158, -0.009520279243588448, -0.0013822759501636028, 0.005432894919067621, 0.0048836348578333855, -0.006969304755330086, 0.009624844416975975, 0.002221961971372366, 0.012857971712946892, 0.00029701043968088925, 0.01846439018845558, -0.0010167461587116122, 0.003362444229424, 0.018802843987941742, 0.007899937219917774, 0.007362710777670145, -0.01259714737534523, -0.004353848285973072, 0.009529988281428814, -0.010434559546411037, -0.00114968151319772, -0.015083436854183674, -0.000513288367073983, -0.012708902359008789, 0.009817283600568771, -0.01616625115275383, 0.02099389024078846, 0.0011255480349063873, 0.002087818691506982, 0.004141946788877249, -0.004236323293298483, -0.008043688721954823, -0.01005634292960167, -0.01211725827306509, 0.008502005599439144, -0.003608573926612735, 0.0013564183609560132, 0.0025668770540505648, 0.006136550568044186, 0.01201977115124464, 0.025350334122776985, 0.012129077687859535, -0.0032070353627204895, 0.0003363523574080318, -0.00036432372871786356, 6.168012623675168e-05, -0.010630798526108265, -0.001909107668325305, 0.0031751287169754505, 0.023505495861172676, 0.006265965756028891, -0.000127899824292399, 0.027016283944249153, -0.003139286069199443, 0.007756243925541639, 0.012871683575212955, -0.005639276001602411, 0.006201721262186766, -0.007926560938358307, -0.007784618530422449, 0.0018705694237723947, 0.011534891091287136, 0.003263025777414441, -0.008872064761817455, 0.011395181529223919, -0.0061043002642691135, -0.0020404469687491655, -0.002608739770948887, 0.0014213520335033536, 7.451167039107531e-05, -0.0006841712747700512, 0.00638044998049736, 0.02477930672466755, -0.0001436929014744237, 0.009057887829840183, -0.0005273017450235784, -0.016813546419143677, -0.011913691647350788, 0.008906804956495762, 0.0009794242214411497, -0.002302555600181222, 0.012043703347444534, -0.00899094995111227, 0.0017951279878616333, 0.010412560775876045, -0.0018607425736263394, 0.003462078981101513, 0.010766361840069294, 0.007733880076557398, 0.0037096040323376656, 0.013803163543343544, -0.010908039286732674, 0.012310138903558254, 0.011781050823628902, -0.022002901881933212, 0.01713118888437748, 0.004179463256150484, 0.00031042384216561913, 0.002601897343993187, 0.002513417275622487, -0.0009618049371056259, 0.0004392332921270281, -0.01457865722477436, -0.00014599964197259396, -0.016580622643232346, -0.014996372163295746, 0.0037527403328567743, -0.008046748116612434, 0.0020090779289603233, -0.0033000593539327383, 0.0052381521090865135, 0.001489385380409658, -0.010585324838757515, -0.015648989006876945, -0.0017674925038591027, 0.004076640121638775, 0.019764112308621407, -0.010859061032533646, -0.0004517346096690744, 0.019380798563361168, -0.0069849626161158085, 0.006805767305195332, 0.0010639704996719956, -0.0008604522445239127, 0.014004685916006565, -0.009115546941757202, 0.020015710964798927, 0.0020416032057255507, -0.005136480089277029, -0.014490129426121712, -0.013324854895472527, 0.02655731327831745, -0.006656433921307325, -0.003929227590560913, 0.0036637713201344013, -0.002843528985977173, 0.016735395416617393, -0.0026499521918594837, -9.097589645534754e-05, 0.0009586272644810379, -0.008444109931588173, -0.01267432514578104, -0.00987020693719387, 0.024359915405511856, 0.006355916149914265, -0.010813186876475811, -0.008563755080103874, 0.001825400278903544, 0.015858624130487442, 0.008124123327434063, 0.007248807232826948, 0.0014658113941550255, 0.00889967568218708, -0.018838992342352867, 0.012073985300958157, 0.007954470813274384, 0.0040962728671729565, 0.003913100343197584, -0.009089702740311623, -0.01628614217042923, -0.005144990514963865, 0.006413666065782309, -0.014760496094822884, -0.010384202003479004, -0.013040153309702873, 0.018687430769205093, -0.0015114899724721909, -0.007259591482579708, -0.002887800568714738, 0.01662031002342701, 0.01137789711356163, 0.0007028636755421758, 0.004064011387526989, 0.007000538054853678, -0.005053787026554346, 0.012506005354225636, -0.009795127436518669, 0.0013481989735737443, 0.01153571903705597, 0.0005770407733507454, 0.00843889731913805, -0.004305741284042597, -0.00523682776838541, -0.003950543235987425, -0.030804451555013657, 0.0019242960261180997, -0.001121998648159206, -0.00024091487284749746, -0.011150459758937359, -0.0006252250750549138, -0.008173838257789612, 0.0025749732740223408, 0.012828282080590725, 0.0037352831568568945, 0.017000781372189522, 0.019764423370361328, -0.006141415797173977, 0.009571244940161705, 0.0060485838912427425, -0.003701487323269248, 0.012796318158507347, -0.005578612443059683, -0.002949218498542905, -0.0077390591613948345, 0.013134066946804523, 0.01348984893411398, -0.017807014286518097, -0.007176446728408337, -0.0025080926716327667, 0.0026706154458224773, 0.006270487792789936, -0.001722719520330429, 0.014065185561776161, -0.012098010629415512, -0.006948764901608229, 0.007107093930244446, 0.011564299464225769, 0.015365575440227985, -0.001405196264386177, 0.013922395184636116, -0.011645046062767506, 0.019185440614819527, 0.008669205941259861, -0.006402950268238783, -0.02930634282529354, 0.012674490921199322, 0.0020986138842999935, 0.0073446049354970455, -0.008211275562644005, -0.010451032780110836, 0.0039263502694666386, 0.010214317589998245, 0.005137004889547825, 0.003445018082857132, 0.011561079882085323, 0.014101037755608559, -0.0009882557205855846, -0.020502908155322075, -0.01304234005510807, -0.0010743135353550315, 0.006024991162121296, 0.005487101152539253, 0.015423699282109737, 0.010340548120439053, 0.007406091317534447, -0.002593359677121043, -0.014484986662864685, -0.008927910588681698, 0.01809288188815117, -8.335654797519965e-07, 0.016086341813206673, -0.007467319723218679, 0.006248668301850557, 0.010234993882477283, 0.01105313841253519, 0.017129629850387573, 0.01001526229083538, 0.008971969597041607, 0.0019335917895659804, 0.014964735135436058, 0.0032235209364444017, -0.01558700855821371, -0.0002231745602330193, 1.0325457878934685e-05, 0.005610847380012274, -0.0014725240180268884, -0.007073611952364445, 0.006045978516340256, 0.00661000469699502, 0.016458045691251755, 0.004776420537382364, -0.0024809655733406544, -0.001141632441431284, 0.020380591973662376, 0.001985315466299653, 0.0016066418029367924, -0.0002630813978612423, -0.0023715763818472624, -0.005043766926974058, -0.0006530124810524285, -0.0011994787491858006, -0.005200596526265144, -0.011772070080041885, 0.020897328853607178, -0.0030809161253273487, 0.0151188550516963, -0.002648375928401947, -0.0020464551635086536, -0.011824622750282288, 0.006495086010545492, -0.006923593580722809, -0.010816085152328014, -0.0005119291599839926, 0.019806137308478355, 0.01142488420009613, 0.006749970838427544, -0.015500311739742756, -0.00845906138420105, -0.0021400016266852617, -0.0008869305020198226, 0.016512639820575714, 0.0022560760844498873, -0.003913175780326128, 0.007876270450651646, 0.0010393362026661634, 0.009496960788965225, 0.008087781257927418, 0.0075989654287695885, 0.007650574669241905, 0.011847944930195808, 0.006889567244797945, 0.005724477581679821, 0.003269805572926998, 0.017861217260360718, 0.014128664508461952, -0.018311243504285812, -0.00034174948814325035, 0.010595879517495632, -0.007073213811963797, -0.0037547526881098747, 0.007721452973783016, -0.013909009285271168, 0.009047291241586208, -0.005391568876802921, 0.010471170768141747, -0.005925311706960201, -0.010769817978143692, -0.01213113870471716, 0.013672600500285625, -0.025970296934247017, 0.004848531447350979, -3.2736243156250566e-05, -0.0014015173073858023, -0.006384227890521288, 0.00352313625626266, 0.007666777819395065, 0.011333705857396126, -0.02562958188354969, 0.007899546064436436, -0.0003998324682470411, 0.00041209004120901227, 0.003894905559718609, 0.011146043427288532, 0.012012973427772522, 0.005776166915893555, 0.01267810445278883, 0.015597822144627571, -0.006004557479172945, -0.005119791720062494, 0.004542575217783451, -0.006028160918504, 0.0002760722709354013, 0.008363571017980576, 0.004950536414980888, 0.006884160451591015, 0.002470653271302581, 0.008553436025977135, 0.016060978174209595, -0.004487414378672838, 0.0008720596670173109, -0.0011375041212886572, 0.006429357919842005, -0.014677043072879314, 0.011124462820589542, -0.0023991605266928673, -0.003062682691961527, -0.0004514633328653872, 0.007666073273867369, 0.008637756109237671, -0.009318140335381031, 0.017036888748407364, -0.001699244137853384, -0.001919509842991829, -0.010616443119943142, -0.014744545333087444, 0.0023425209801644087, 0.015598481521010399, 0.004587096627801657, -0.006646877154707909, -0.0003353591891936958, 0.018362227827310562, -0.0008180644363164902, -0.009953021071851254, 0.009997352957725525, 0.006935094483196735, -0.0041261701844632626, 0.008880453184247017, 0.014008709229528904, -0.002823092043399811, 0.00992603413760662, -0.012414450757205486, 0.00843184906989336, 0.0020578710827976465, 0.016727814450860023, 0.019950158894062042, -0.004741964861750603, -0.0005425591953098774, -0.0058597358874976635, -0.0017816543113440275, 0.0027736839838325977, -0.012360996566712856, 0.003590812673792243, 0.005126148462295532, 0.004654150456190109, -0.006369463168084621, -0.003772721393033862, 0.02091693878173828, 0.01528538204729557, -0.003835881594568491, -0.003409450873732567, 0.010528912767767906, -0.0017828766722232103, 0.02516132965683937, 0.00814846158027649, -0.004559976980090141, -0.011295965872704983, 0.010371438227593899, 0.0006707622087560594, 0.005064355209469795, -0.010145595297217369, 0.004426772706210613, -0.005737415049225092, 0.0012003653682768345, 0.0036150291562080383, 0.011215592734515667, -0.0040545896627008915, 0.0034540158230811357, -0.008764381520450115, 0.014239713549613953, -0.0008067164453677833, -0.0077876923605799675, -0.018948623910546303, 0.003514879383146763, 0.0030799901578575373, 0.017965681850910187, -2.2737660401617177e-05, 0.007884345017373562, 0.003565009217709303, -0.012896131724119186, 0.0026204099413007498, -0.0020277798175811768, -0.005030298605561256, 0.0028232778422534466, -0.010799753479659557, 0.0004869748081546277, 0.00079371128231287, 0.0038255134131759405, 0.01625426858663559, -0.007327786646783352, -0.0006097709410823882, 0.021006541326642036, -0.004630157724022865, -0.015508498065173626, -0.011421660892665386, 0.004675465635955334, -0.0020752830896526575, 0.00651256600394845, -0.001388593576848507, -0.018634997308254242, -0.013168826699256897, 0.00600035535171628, -0.0051271733827888966, -0.0009046715567819774, 0.008201603777706623, 0.004723501857370138, -0.003945730160921812, 0.002696358598768711, -0.00534584978595376, 0.014200250618159771, 0.0032823574729263783, -0.009708631783723831, -0.002968631684780121, -0.0013192173792049289, -0.02307787910103798, 0.0048147342167794704, 0.009474620223045349, 0.0016143537359312177, 0.0026406359393149614, -0.004522481467574835, -0.00021635316079482436, 0.003252497175708413, -0.012805595062673092, 0.0034197745844721794, -0.016244668513536453, 0.001594359870068729, 0.01458613108843565, 0.010125475935637951, -0.009164906106889248, 0.006422918755561113, 0.002428187755867839, -0.005929546430706978, -0.012289043515920639, -0.00899637583643198, 6.348137685563415e-05, 0.01821526326239109, -0.0014832826564088464, -0.010244476608932018, -0.0010153147159144282, -0.005283149890601635, -0.012473183684051037, 0.007026445120573044, -0.011461296118795872, 0.0033159609884023666, -0.01289116870611906, 0.01181294210255146, 0.012677633203566074, -0.0013208106392994523, 0.003216641489416361, -0.012547919526696205, 0.004124876111745834, 0.0009470342774875462, -0.002027716487646103, 0.02221308834850788, -0.0008846950950101018, -0.0014873967738822103, 0.0159499179571867, 0.006823782809078693, -0.014584256336092949, -0.008290774188935757, -0.0120676439255476, -0.005099371075630188, -0.0080600306391716, 0.014079704880714417, -0.0020367265678942204, -0.008362890221178532, -0.001621987670660019, 0.010619360022246838, 0.013944074511528015, 0.0009812230709940195, 0.01836475357413292, 0.00273580988869071, -0.008425148203969002, -0.009124254807829857, 0.00011686794459819794, 0.0009999654721468687, 0.015176774933934212, -0.0009844052838161588, 0.013951435685157776, 0.004029604606330395, 0.006666754838079214, 0.003529589157551527, 0.003615013789385557, -0.02800256945192814, 0.006249892991036177, -0.008050324395298958, 0.005056394264101982, -0.006130233872681856, 0.006157447583973408, -0.010642053559422493, 0.0012126392684876919, -0.006806216202676296, 0.004176209215074778, -0.006571719888597727, -0.009268700145184994, -0.0075151631608605385, -0.0015292258467525244, -0.010599321685731411, 0.003558061085641384, 0.007374519016593695, 0.003586599137634039, 0.017905788496136665, 0.008015276864171028, -0.009034614078700542, 0.0018911826191470027, -0.005265430081635714, 0.004869093652814627, -0.0001151503311120905, -0.003885870799422264, 0.004713456612080336, 0.005261662881821394, -0.006401803344488144, 0.004026818089187145, 0.000506703567225486, 0.01196456328034401, -0.007442399859428406, 0.007578311488032341, 0.011165248230099678, -0.0009072477114386857, 0.010882778093218803, -0.0029303899500519037, 0.0024887293111532927, 0.004268537741154432, -0.003925780300050974, 0.01125251967459917, 0.0056076375767588615, 0.0013249896001070738, -0.0064854929223656654, 0.008991091512143612, 0.006860945839434862, 0.01103312149643898, 0.009761952795088291, -0.010707248002290726, -0.013999280519783497, 0.00729756522923708, -0.011525528505444527, 0.007183250039815903, 0.020575670525431633, -0.0008294707513414323, 0.002093465067446232, 0.005759619176387787, -0.0024755089543759823, 0.006597691681236029, 0.01883961632847786, -0.0017893409822136164, 0.004597699735313654, 0.0020806791726499796, -0.016394654288887978, 0.00436061155050993, -0.0021841854322701693, -0.01360281091183424, 0.0053757247515022755, 0.0006083177286200225, 0.0009018208365887403, 0.023406485095620155, -0.0009405831224285066, 0.018612656742334366, 0.0033334933687001467, -0.003511708928272128, -0.0026272705290466547, -0.0007344239274971187, 0.00568321393802762, -0.0053346729837358, -0.0017605028115212917, -0.010383752174675465, 0.018871944397687912, 0.0006560813635587692, -0.010139352641999722, 0.0026594139635562897, -0.0033114543184638023, 0.017703739926218987, -0.00698117958381772, -0.0020041917450726032, -0.004826934076845646, 0.0005630375235341489, 0.0032594038639217615, -0.0009713696199469268, -0.008052065037190914, -0.0027339174412190914, -0.010677228681743145, 0.004025970585644245, -0.012248994782567024, 0.015234006568789482, -0.005934533197432756, 0.01589520275592804, -0.012386002577841282, 0.009487216360867023, -0.0009444615570828319, 0.0125525938346982, -0.0013922285288572311, 0.013387584127485752, 0.0037463558837771416, -0.007817583158612251, 0.01384445745497942, -0.005437555257230997, -0.0015223644440993667, -0.015931112691760063, -0.010777517221868038, 0.02394471876323223, 0.011097057722508907, 0.010094624944031239, 0.012511318549513817, 0.0038457082118839025, 0.009449491277337074, 0.005493646953254938, 0.011773099191486835, 0.0019221196416765451, -0.013554790988564491, 0.01794261485338211, -0.0231052003800869, 0.008189410902559757, 0.006517274770885706, -0.004495919682085514, 0.001762248226441443, 0.008912311866879463, 0.009711232967674732, -0.00761059345677495, 0.02114887163043022] + + +class SpecAugment(nn.Module): + """ + Zeroes out(cuts) random continuous horisontal or + vertical segments of the spectrogram as described in + SpecAugment (https://arxiv.org/abs/1904.08779). + + params: + freq_masks - how many frequency segments should be cut + time_masks - how many time segments should be cut + freq_width - maximum number of frequencies to be cut in one segment + time_width - maximum number of time steps to be cut in one segment. + Can be a positive integer or a float value in the range [0, 1]. + If positive integer value, defines maximum number of time steps + to be cut in one segment. + If a float value, defines maximum percentage of timesteps that + are cut adaptively. + """ + + def __init__( + self, freq_masks=0, time_masks=0, freq_width=10, time_width=10, max_time_masks=20, gauss_mask_std=0.0, rng=None, + ): + super(SpecAugment, self).__init__() + + self._rng = random.Random() if rng is None else rng + + self.freq_masks = freq_masks + self.time_masks = time_masks + self.max_time_masks = max_time_masks + + self.freq_width = freq_width + self.time_width = time_width + + self.gauss_mask_std = gauss_mask_std + + if isinstance(time_width, int): + self.adaptive_temporal_width = False + else: + if time_width > 1.0 or time_width < 0.0: + raise ValueError('If `time_width` is a float value, must be in range [0, 1]') + + self.adaptive_temporal_width = True + + if isinstance(time_masks, int): + self.adaptive_time_mask = False + else: + if time_masks >= 1.0 or time_masks < 0.0: + raise ValueError('If `time_width` is a float value, must be in range [0, 1]') + + self.adaptive_time_mask = True + + @torch.no_grad() + def forward(self, x, length): + B, D, T = x.shape + + for idx in range(B): + for _ in range(self.freq_masks): + x_left = self._rng.randint(0, D - self.freq_width) + + w = self._rng.randint(0, self.freq_width) + + x[idx, x_left : x_left + w, :] = 0.0 + + if self.adaptive_temporal_width: + time_width = max(1, int(length[idx] * self.time_width)) + else: + time_width = self.time_width + + if self.adaptive_time_mask: + time_masks = int(length[idx] * self.time_masks) + time_masks = min(time_masks, self.max_time_masks) + else: + time_masks = self.time_masks + + for _ in range(time_masks): + y_left = self._rng.randint(0, length[idx] - time_width) + + w = self._rng.randint(0, time_width) + + if self.gauss_mask_std == 0: + x[idx, :, y_left:y_left + w] = 0.0 + else: + x[idx, :, y_left:y_left + w] = torch.normal(mean=0, std=self.gauss_mask_std, size=(D, w)).to(x.device) + + return x + + +class SpecCutout(nn.Module): + """ + Zeroes out(cuts) random rectangles in the spectrogram + as described in (https://arxiv.org/abs/1708.04552). + + params: + rect_masks - how many rectangular masks should be cut + rect_freq - maximum size of cut rectangles along the frequency dimension + rect_time - maximum size of cut rectangles along the time dimension + """ + + def __init__(self, rect_masks=0, rect_time=5, rect_freq=20, rng=None): + super(SpecCutout, self).__init__() + + self._rng = random.Random() if rng is None else rng + + self.rect_masks = rect_masks + self.rect_time = rect_time + self.rect_freq = rect_freq + + @torch.no_grad() + def forward(self, x): + sh = x.shape + + for idx in range(sh[0]): + for i in range(self.rect_masks): + rect_x = self._rng.randint(0, sh[1] - self.rect_freq) + rect_y = self._rng.randint(0, sh[2] - self.rect_time) + + w_x = self._rng.randint(0, self.rect_time) + w_y = self._rng.randint(0, self.rect_freq) + + x[idx, rect_x : rect_x + w_x, rect_y : rect_y + w_y] = 0.0 + + return x diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/subsampling.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/subsampling.py new file mode 100644 index 0000000000000000000000000000000000000000..5477982e3f00feaaa1b7b7fa589ededdea7931dd --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/subsampling.py @@ -0,0 +1,138 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math + +import torch +import torch.nn as nn + + +class ConvSubsampling(torch.nn.Module): + """Convolutional subsampling which supports VGGNet and striding approach introduced in: + VGGNet Subsampling: https://arxiv.org/pdf/1910.12977.pdf + Striding Subsampling: + "Speech-Transformer: A No-Recurrence Sequence-to-Sequence Model for Speech Recognition" by Linhao Dong et al. + Args: + subsampling (str): The subsampling technique from {"vggnet", "striding"} + subsampling_factor (int): The subsampling factor which should be a power of 2 + feat_in (int): size of the input features + feat_out (int): size of the output features + conv_channels (int): Number of channels for the convolution layers. + activation (Module): activation function, default is nn.ReLU() + """ + + def __init__(self, subsampling, subsampling_factor, feat_in, feat_out, conv_channels, activation=nn.ReLU()): + super(ConvSubsampling, self).__init__() + self._subsampling = subsampling + + if subsampling_factor % 2 != 0: + raise ValueError("Sampling factor should be a multiply of 2!") + self._sampling_num = int(math.log(subsampling_factor, 2)) + + in_channels = 1 + layers = [] + if subsampling == 'vggnet': + self._padding = 0 + self._stride = 2 + self._kernel_size = 2 + self._ceil_mode = True + + for i in range(self._sampling_num): + layers.append( + torch.nn.Conv2d( + in_channels=in_channels, out_channels=conv_channels, kernel_size=3, stride=1, padding=1 + ) + ) + layers.append(activation) + layers.append( + torch.nn.Conv2d( + in_channels=conv_channels, out_channels=conv_channels, kernel_size=3, stride=1, padding=1 + ) + ) + layers.append(activation) + layers.append( + torch.nn.MaxPool2d( + kernel_size=self._kernel_size, + stride=self._stride, + padding=self._padding, + ceil_mode=self._ceil_mode, + ) + ) + in_channels = conv_channels + elif subsampling == 'striding': + self._padding = 0 + self._stride = 2 + self._kernel_size = 3 + self._ceil_mode = False + + for i in range(self._sampling_num): + layers.append( + torch.nn.Conv2d( + in_channels=in_channels, + out_channels=conv_channels, + kernel_size=self._kernel_size, + stride=self._stride, + padding=self._padding, + ) + ) + layers.append(activation) + in_channels = conv_channels + else: + raise ValueError(f"Not valid sub-sampling: {subsampling}!") + + in_length = feat_in + for i in range(self._sampling_num): + out_length = calc_length( + length=int(in_length), + padding=self._padding, + kernel_size=self._kernel_size, + stride=self._stride, + ceil_mode=self._ceil_mode, + ) + in_length = out_length + + self.out = torch.nn.Linear(conv_channels * out_length, feat_out) + self.conv = torch.nn.Sequential(*layers) + + def forward(self, x, lengths): + x = x.unsqueeze(1) + x = self.conv(x) + b, c, t, f = x.size() + x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) + + # TODO: improve the performance of length calculation + new_lengths = lengths + for i in range(self._sampling_num): + new_lengths = [ + calc_length( + length=int(length), + padding=self._padding, + kernel_size=self._kernel_size, + stride=self._stride, + ceil_mode=self._ceil_mode, + ) + for length in new_lengths + ] + + new_lengths = torch.IntTensor(new_lengths).to(lengths.device) + return x, new_lengths + + +def calc_length(length, padding, kernel_size, stride, ceil_mode): + """ Calculates the output length of a Tensor passed through a convolution or max pooling layer""" + if ceil_mode: + length = math.ceil((length + (2 * padding) - (kernel_size - 1) - 1) / float(stride) + 1) + else: + length = math.floor((length + (2 * padding) - (kernel_size - 1) - 1) / float(stride) + 1) + return length diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/transformert_beam_decoding.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/transformert_beam_decoding.py new file mode 100644 index 0000000000000000000000000000000000000000..5f8cb084b36262d45753dc7f34327c5431f0c142 --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/transformert_beam_decoding.py @@ -0,0 +1,869 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright 2017 Johns Hopkins University (Shinji Watanabe) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional, Union + +import numpy as np +import torch +from tqdm import tqdm + +from nemo.collections.asr.modules import rnnt_abstract, TransformerTDecoder +from nemo.collections.asr.parts import rnnt_utils +from nemo.collections.asr.parts.rnnt_utils import Hypothesis, NBestHypotheses +from nemo.core.classes import Typing, typecheck +from nemo.core.neural_types import AcousticEncodedRepresentation, HypothesisType, LengthsType, NeuralType + + +class BeamTransformerTInfer(Typing): + @property + def input_types(self): + """Returns definitions of module input ports. + """ + return { + "encoder_output": NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation()), + "encoded_lengths": NeuralType(tuple('B'), LengthsType()), + } + + @property + def output_types(self): + """Returns definitions of module output ports. + """ + return {"predictions": NeuralType(elements_type=HypothesisType())} + + def __init__( + self, + decoder_model: TransformerTDecoder, + joint_model: rnnt_abstract.AbstractRNNTJoint, + blank_index: int, + beam_size: int, + beam_temperature: float, + beam_combine_path: bool, + beam_max_exp_step: int, + beam_prune_exp: bool, + beam_prune_exp_full: bool, + beam_word_reward_ratio: float, + search_type: str = 'convtt', + score_norm: bool = True, + return_best_hypothesis: bool = True, + tsd_max_sym_exp_per_step: Optional[int] = 50, + alsd_max_target_len: Union[int, float] = 1.0, + nsc_max_timesteps_expansion: int = 1, + nsc_prefix_alpha: int = 1, + ): + """ + Beam Search implementation ported from ESPNet implementation - + https://github.com/espnet/espnet/blob/master/espnet/nets/beam_search_transducer.py + + Sequence level beam decoding or batched-beam decoding, performed auto-repressively + depending on the search type chosen. + + Args: + decoder_model: rnnt_utils.AbstractRNNTDecoder implementation. + joint_model: rnnt_utils.AbstractRNNTJoint implementation. + + beam_size: number of beams for beam search. Must be a positive integer >= 1. + If beam size is 1, defaults to stateful greedy search. + This greedy search might result in slightly different results than + the greedy results obtained by GreedyRNNTInfer due to implementation differences. + + For accurate greedy results, please use GreedyRNNTInfer or GreedyBatchedRNNTInfer. + + search_type: str representing the type of beam search to perform. + Must be one of ['beam', 'tsd', 'alsd']. 'nsc' is currently not supported. + + Algoritm used: + `beam` - basic beam search strategy. Larger beams generally result in better decoding, + however the time required for the search also grows steadily. + + `tsd` - time synchronous decoding. Please refer to the paper: + [Alignment-Length Synchronous Decoding for RNN Transducer](https://ieeexplore.ieee.org/document/9053040) + for details on the algorithm implemented. + + Time synchronous decoding (TSD) execution time grows by the factor T * max_symmetric_expansions. + For longer sequences, T is greater, and can therefore take a long time for beams to obtain + good results. This also requires greater memory to execute. + + `alsd` - alignment-length synchronous decoding. Please refer to the paper: + [Alignment-Length Synchronous Decoding for RNN Transducer](https://ieeexplore.ieee.org/document/9053040) + for details on the algorithm implemented. + + Alignment-length synchronous decoding (ALSD) execution time is faster than TSD, with growth + factor of T + U_max, where U_max is the maximum target length expected during execution. + + Generally, T + U_max < T * max_symmetric_expansions. However, ALSD beams are non-unique, + therefore it is required to use larger beam sizes to achieve the same (or close to the same) + decoding accuracy as TSD. + + For a given decoding accuracy, it is possible to attain faster decoding via ALSD than TSD. + + score_norm: bool, whether to normalize the scores of the log probabilities. + + return_best_hypothesis: bool, decides whether to return a single hypothesis (the best out of N), + or return all N hypothesis (sorted with best score first). The container class changes based + this flag - + When set to True (default), returns a single Hypothesis. + When set to False, returns a NBestHypotheses container, which contains a list of Hypothesis. + + # The following arguments are specific to the chosen `search_type` + + tsd_max_sym_exp_per_step: Used for `search_type=tsd`. The maximum symmetric expansions allowed + per timestep during beam search. Larger values should be used to attempt decoding of longer + sequences, but this in turn increases execution time and memory usage. + + alsd_max_target_len: Used for `search_type=alsd`. The maximum expected target sequence length + during beam search. Larger values allow decoding of longer sequences at the expense of + execution time and memory. + + # The following two flags are placeholders and unused until `nsc` implementation is stabilized. + + nsc_max_timesteps_expansion: Unused int. + + nsc_prefix_alpha: Unused int. + """ + self.decoder = decoder_model + self.joint = joint_model + + self.blank = blank_index + assert self.blank == self.decoder.blank_idx + self.vocab_size = decoder_model.vocab_size + self.search_type = search_type + self.return_best_hypothesis = return_best_hypothesis + + if beam_size < 1: + raise ValueError("Beam search size cannot be less than 1!") + + self.beam_size = beam_size + self.score_norm = score_norm + + self.beam_stepwise_ln_alpha = 0. + self.beam_word_reward_ratio = beam_word_reward_ratio + if self.beam_stepwise_ln_alpha > 0: + assert self.score_norm + assert self.beam_word_reward_ratio == 0 + if self.beam_word_reward_ratio > 0: + assert self.score_norm + assert self.beam_stepwise_ln_alpha == 0 + self.beam_combine_path = beam_combine_path + self.beam_max_exp_step = beam_max_exp_step + self.beam_temperature = beam_temperature + self.beam_prune_exp = beam_prune_exp + self.beam_prune_exp_full = beam_prune_exp_full + + assert self.beam_size > 1 and self.search_type == 'convtt' + + if self.beam_size == 1: + self.search_algorithm = self.greedy_search + elif search_type == "default": + self.search_algorithm = self.default_beam_search + elif search_type == "tsd": + self.search_algorithm = self.time_sync_decoding + elif search_type == "alsd": + self.search_algorithm = self.align_length_sync_decoding + elif search_type == "convtt": + self.search_algorithm = self.convtt_beam_search + elif search_type == "nsc": + raise NotImplementedError("`nsc` (Constrained Beam Search) has not been implemented.") + # self.search_algorithm = self.nsc_beam_search + else: + raise NotImplementedError( + f"The search type ({search_type}) supplied is not supported!\n" + f"Please use one of : (default, tsd, alsd, nsc)" + ) + + if tsd_max_sym_exp_per_step is None: + tsd_max_sym_exp_per_step = -1 + + if search_type in ['tsd', 'alsd', 'nsc'] and not self.decoder.blank_as_pad: + raise ValueError( + f"Search type was chosen as '{search_type}', however the decoder module provided " + f"does not support the `blank` token as a pad value. {search_type} requires " + f"the blank token as pad value support in order to perform batched beam search." + f"Please chose one of the other beam search methods, or re-train your model " + f"with this support." + ) + + self.tsd_max_symmetric_expansion_per_step = tsd_max_sym_exp_per_step + self.alsd_max_target_length = alsd_max_target_len + self.nsc_max_timesteps_expansion = nsc_max_timesteps_expansion + self.nsc_prefix_alpha = nsc_prefix_alpha + + @typecheck() + def __call__( + self, encoder_output: torch.Tensor, encoded_lengths: torch.Tensor + ) -> Union[Hypothesis, NBestHypotheses]: + """Perform general beam search. + + Args: + encoder_output: Encoded speech features (B, T_max, D_enc) + encoded_lengths: Lengths of the encoder outputs + + Returns: + Either a list containing a single Hypothesis (when `return_best_hypothesis=True`, + otherwise a list containing a single NBestHypotheses, which itself contains a list of + Hypothesis. This list is sorted such that the best hypothesis is the first element. + """ + # Preserve decoder and joint training state + decoder_training_state = self.decoder.training + joint_training_state = self.joint.training + self.decoder.eval() + self.joint.eval() + + with torch.no_grad(): + # Apply optional preprocessing + encoder_output = encoder_output.transpose(1, 2) # (B, T, D) + + hypotheses = [] + with tqdm( + range(encoder_output.size(0)), + desc='Beam search progress:', + total=encoder_output.size(0), + unit='sample', + ) as idx_gen: + # Decode every sample in the batch independently. + for batch_idx in idx_gen: + inseq = encoder_output[batch_idx : batch_idx + 1, :, :] # [1, T, D] + logitlen = encoded_lengths[batch_idx] + + # Execute the specific search strategy + nbest_hyps = self.search_algorithm(inseq, logitlen) # sorted list of hypothesis + + # Pack the result + if self.return_best_hypothesis: + best_hypothesis = nbest_hyps[0] # type: Hypothesis + else: + best_hypothesis = NBestHypotheses(nbest_hyps) # type: NBestHypotheses + hypotheses.append(best_hypothesis) + + self.decoder.train(decoder_training_state) + self.joint.train(joint_training_state) + + return (hypotheses,) + + def sort_nbest(self, hyps: List[Hypothesis]) -> List[Hypothesis]: + """Sort hypotheses by score or score given sequence length. + + Args: + hyps: list of hypotheses + + Return: + hyps: sorted list of hypotheses + """ + if self.score_norm: + return sorted(hyps, key=lambda x: x.score / len(x.y_sequence), reverse=True) + else: + return sorted(hyps, key=lambda x: x.score, reverse=True) + + def greedy_search(self, h: torch.Tensor, encoded_lengths: torch.Tensor) -> List[Hypothesis]: + """Greedy search implementation for transducer. + Generic case when beam size = 1. Results might differ slightly due to implementation details + as compared to `GreedyRNNTInfer` and `GreedyBatchRNNTInfer`. + + Args: + h: Encoded speech features (1, T_max, D_enc) + + Returns: + hyp: 1-best decoding results + """ + # Initialize zero state vectors + dec_state = self.decoder.initialize_state(h) + + # Construct initial hypothesis + hyp = Hypothesis(score=0.0, y_sequence=[self.blank], dec_state=dec_state) + cache = {} + + # Initialize state and first token + y, state, _ = self.decoder.score_hypothesis(hyp, cache) + + for i in range(int(encoded_lengths)): + hi = h[:, i : i + 1, :] # [1, 1, D] + + not_blank = True + symbols_added = 0 + + while not_blank: + ytu = torch.log_softmax(self.joint.joint(hi, y), dim=-1) # [1, 1, 1, V + 1] + ytu = ytu[0, 0, 0, :] # [V + 1] + + # max() requires float + if ytu.dtype != torch.float32: + ytu = ytu.float() + + logp, pred = torch.max(ytu, dim=-1) # [1, 1] + pred = pred.item() + + if pred == self.blank: + not_blank = False + else: + # Update state and current sequence + hyp.y_sequence.append(int(pred)) + hyp.score += float(logp) + hyp.dec_state = state + + # Compute next state and token + y, state, _ = self.decoder.score_hypothesis(hyp, cache) + symbols_added += 1 + + return [hyp] + + def default_beam_search(self, h: torch.Tensor, encoded_lengths: torch.Tensor) -> List[Hypothesis]: + """Beam search implementation. + + Args: + x: Encoded speech features (1, T_max, D_enc) + + Returns: + nbest_hyps: N-best decoding results + """ + # Initialize states + beam = min(self.beam_size, self.vocab_size) + beam_k = min(beam, (self.vocab_size - 1)) + blank_tensor = torch.tensor([self.blank], device=h.device, dtype=torch.long) + + # Precompute some constants for blank position + ids = list(range(self.vocab_size + 1)) + ids.remove(self.blank) + + # Used when blank token is first vs last token + if self.blank == 0: + index_incr = 1 + else: + index_incr = 0 + + # Initialize zero vector states + dec_state = self.decoder.initialize_state(h) + + # Initialize first hypothesis for the beam (blank) + kept_hyps = [Hypothesis(score=0.0, y_sequence=[self.blank], dec_state=dec_state)] + cache = {} + + for i in range(int(encoded_lengths)): + hi = h[:, i : i + 1, :] # [1, 1, D] + hyps = kept_hyps + kept_hyps = [] + + while True: + max_hyp = max(hyps, key=lambda x: x.score) + hyps.remove(max_hyp) + + # update decoder state and get next score + y, state, lm_tokens = self.decoder.score_hypothesis(max_hyp, cache) # [1, 1, D] + + # get next token + ytu = torch.log_softmax(self.joint.joint(hi, y), dim=-1) # [1, 1, 1, V + 1] + ytu = ytu[0, 0, 0, :] # [V + 1] + + # remove blank token before top k + top_k = ytu[ids].topk(beam_k, dim=-1) + + # Two possible steps - blank token or non-blank token predicted + ytu = ( + torch.cat((top_k[0], ytu[self.blank].unsqueeze(0))), + torch.cat((top_k[1] + index_incr, blank_tensor)), + ) + + # for each possible step + for logp, k in zip(*ytu): + # construct hypothesis for step + new_hyp = Hypothesis( + score=(max_hyp.score + float(logp)), + y_sequence=max_hyp.y_sequence[:], + dec_state=max_hyp.dec_state, + lm_state=max_hyp.lm_state, + ) + + # if current token is blank, dont update sequence, just store the current hypothesis + if k == self.blank: + kept_hyps.append(new_hyp) + else: + # if non-blank token was predicted, update state and sequence and then search more hypothesis + new_hyp.dec_state = state + new_hyp.y_sequence.append(int(k)) + + hyps.append(new_hyp) + + # keep those hypothesis that have scores greater than next search generation + hyps_max = float(max(hyps, key=lambda x: x.score).score) + kept_most_prob = sorted([hyp for hyp in kept_hyps if hyp.score > hyps_max], key=lambda x: x.score,) + + # If enough hypothesis have scores greater than next search generation, + # stop beam search. + if len(kept_most_prob) >= beam: + kept_hyps = kept_most_prob + break + + return self.sort_nbest(kept_hyps) + + def time_sync_decoding(self, h: torch.Tensor, encoded_lengths: torch.Tensor) -> List[Hypothesis]: + """Time synchronous beam search implementation. + Based on https://ieeexplore.ieee.org/document/9053040 + + Args: + h: Encoded speech features (1, T_max, D_enc) + + Returns: + nbest_hyps: N-best decoding results + """ + # Precompute some constants for blank position + ids = list(range(self.vocab_size + 1)) + ids.remove(self.blank) + + # Used when blank token is first vs last token + if self.blank == 0: + index_incr = 1 + else: + index_incr = 0 + + # prepare the batched beam states + beam = min(self.beam_size, self.vocab_size) + beam_state = self.decoder.initialize_state( + torch.zeros(beam, device=h.device, dtype=h.dtype) + ) # [L, B, H], [L, B, H] (for LSTMs) + + # Initialize first hypothesis for the beam (blank) + B = [Hypothesis(y_sequence=[self.blank], score=0.0, dec_state=self.decoder.batch_select_state(beam_state, 0))] + cache = {} + + for i in range(int(encoded_lengths)): + hi = h[:, i : i + 1, :] + + # Update caches + A = [] + C = B + + h_enc = hi + + # For a limited number of symmetric expansions per timestep "i" + for v in range(self.tsd_max_symmetric_expansion_per_step): + D = [] + + # Decode a batch of beam states and scores + beam_y, beam_state, beam_lm_tokens = self.decoder.batch_score_hypothesis(C, cache, beam_state) + + # Extract the log probabilities and the predicted tokens + beam_logp = torch.log_softmax(self.joint.joint(h_enc, beam_y), dim=-1) # [B, 1, 1, V + 1] + beam_logp = beam_logp[:, 0, 0, :] # [B, V + 1] + beam_topk = beam_logp[:, ids].topk(beam, dim=-1) + + seq_A = [h.y_sequence for h in A] + + for j, hyp in enumerate(C): + # create a new hypothesis in A + if hyp.y_sequence not in seq_A: + # If the sequence is not in seq_A, add it as the blank token + # In this step, we dont add a token but simply update score + A.append( + Hypothesis( + score=(hyp.score + float(beam_logp[j, self.blank])), + y_sequence=hyp.y_sequence[:], + dec_state=hyp.dec_state, + lm_state=hyp.lm_state, + ) + ) + else: + # merge the existing blank hypothesis score with current score. + dict_pos = seq_A.index(hyp.y_sequence) + + A[dict_pos].score = np.logaddexp( + A[dict_pos].score, (hyp.score + float(beam_logp[j, self.blank])) + ) + + if v < self.tsd_max_symmetric_expansion_per_step: + for j, hyp in enumerate(C): + # for each current hypothesis j + # extract the top token score and top token id for the jth hypothesis + for logp, k in zip(beam_topk[0][j], beam_topk[1][j] + index_incr): + # create new hypothesis and store in D + # Note: This loop does *not* include the blank token! + new_hyp = Hypothesis( + score=(hyp.score + float(logp)), + y_sequence=(hyp.y_sequence + [int(k)]), + dec_state=self.decoder.batch_select_state(beam_state, j), + lm_state=hyp.lm_state, + ) + + D.append(new_hyp) + + # Prune beam + C = sorted(D, key=lambda x: x.score, reverse=True)[:beam] + + # Prune beam + B = sorted(A, key=lambda x: x.score, reverse=True)[:beam] + + return self.sort_nbest(B) + + def align_length_sync_decoding(self, h: torch.Tensor, encoded_lengths: torch.Tensor) -> List[Hypothesis]: + """Alignment-length synchronous beam search implementation. + Based on https://ieeexplore.ieee.org/document/9053040 + + Args: + h: Encoded speech features (1, T_max, D_enc) + + Returns: + nbest_hyps: N-best decoding results + """ + # Precompute some constants for blank position + ids = list(range(self.vocab_size + 1)) + ids.remove(self.blank) + + # Used when blank token is first vs last token + if self.blank == 0: + index_incr = 1 + else: + index_incr = 0 + + # prepare the batched beam states + beam = min(self.beam_size, self.vocab_size) + + h = h[0] # [T, D] + h_length = int(encoded_lengths) + beam_state = self.decoder.initialize_state( + torch.zeros(beam, device=h.device, dtype=h.dtype) + ) # [L, B, H], [L, B, H] for LSTMS + + # compute u_max as either a specific static limit, + # or a multiple of current `h_length` dynamically. + if type(self.alsd_max_target_length) == float: + u_max = int(self.alsd_max_target_length * h_length) + else: + u_max = int(self.alsd_max_target_length) + + # Initialize first hypothesis for the beam (blank) + B = [Hypothesis(y_sequence=[self.blank], score=0.0, dec_state=self.decoder.batch_select_state(beam_state, 0))] + + final = [] + cache = {} + + # ALSD runs for T + U_max steps + for i in range(h_length + u_max): + # Update caches + A = [] + B_ = [] + h_states = [] + + # preserve the list of batch indices which are added into the list + # and those which are removed from the list + # This is necessary to perform state updates in the correct batch indices later + batch_ids = list(range(len(B))) # initialize as a list of all batch ids + batch_removal_ids = [] # update with sample ids which are removed + + for bid, hyp in enumerate(B): + u = len(hyp.y_sequence) - 1 + t = i - u + 1 + + if t > (h_length - 1): + batch_removal_ids.append(bid) + continue + + B_.append(hyp) + h_states.append((t, h[t])) + + if B_: + # Compute the subset of batch ids which were *not* removed from the list above + sub_batch_ids = None + if len(B_) != beam: + sub_batch_ids = batch_ids + for id in batch_removal_ids: + # sub_batch_ids contains list of ids *that were not removed* + sub_batch_ids.remove(id) + + # extract the states of the sub batch only. + beam_state_ = [beam_state[state_id][:, sub_batch_ids, :] for state_id in range(len(beam_state))] + else: + # If entire batch was used (none were removed), simply take all the states + beam_state_ = beam_state + + # Decode a batch/sub-batch of beam states and scores + beam_y, beam_state_, beam_lm_tokens = self.decoder.batch_score_hypothesis(B_, cache, beam_state_) + + # If only a subset of batch ids were updated (some were removed) + if sub_batch_ids is not None: + # For each state in the RNN (2 for LSTM) + for state_id in range(len(beam_state)): + # Update the current batch states with the sub-batch states (in the correct indices) + # These indices are specified by sub_batch_ids, the ids of samples which were updated. + beam_state[state_id][:, sub_batch_ids, :] = beam_state_[state_id][...] + else: + # If entire batch was updated, simply update all the states + beam_state = beam_state_ + + # h_states = list of [t, h[t]] + # so h[1] here is a h[t] of shape [D] + # Simply stack all of the h[t] within the sub_batch/batch (T <= beam) + h_enc = torch.stack([h[1] for h in h_states]) # [T=beam, D] + h_enc = h_enc.unsqueeze(1) # [B=beam, T=1, D]; batch over the beams + + # Extract the log probabilities and the predicted tokens + beam_logp = torch.log_softmax(self.joint.joint(h_enc, beam_y), dim=-1) # [B=beam, 1, 1, V + 1] + beam_logp = beam_logp[:, 0, 0, :] # [B=beam, V + 1] + beam_topk = beam_logp[:, ids].topk(beam, dim=-1) + + for j, hyp in enumerate(B_): + # For all updated samples in the batch, add it as the blank token + # In this step, we dont add a token but simply update score + new_hyp = Hypothesis( + score=(hyp.score + float(beam_logp[j, self.blank])), + y_sequence=hyp.y_sequence[:], + dec_state=hyp.dec_state, + lm_state=hyp.lm_state, + ) + + # Add blank prediction to A + A.append(new_hyp) + + # If the prediction "timestep" t has reached the length of the input sequence + # we can add it to the "finished" hypothesis list. + if h_states[j][0] == (h_length - 1): + final.append(new_hyp) + + # Here, we carefully select the indices of the states that we want to preserve + # for the next token (non-blank) update. + if sub_batch_ids is not None: + h_states_idx = sub_batch_ids[j] + else: + h_states_idx = j + + # for each current hypothesis j + # extract the top token score and top token id for the jth hypothesis + for logp, k in zip(beam_topk[0][j], beam_topk[1][j] + index_incr): + # create new hypothesis and store in A + # Note: This loop does *not* include the blank token! + new_hyp = Hypothesis( + score=(hyp.score + float(logp)), + y_sequence=(hyp.y_sequence[:] + [int(k)]), + dec_state=self.decoder.batch_select_state(beam_state, h_states_idx), + lm_state=hyp.lm_state, + ) + + A.append(new_hyp) + + # Prune and recombine same hypothesis + # This may cause next beam to be smaller than max beam size + # Therefore larger beam sizes may be required for better decoding. + B = sorted(A, key=lambda x: x.score, reverse=True)[:beam] + B = self.recombine_hypotheses(B) + + # If B_ is empty list, then we may be able to early exit + elif len(batch_ids) == len(batch_removal_ids): + break + + if final: + return self.sort_nbest(final) + else: + return B + + def recombine_hypotheses(self, hypotheses: List[Hypothesis]) -> List[Hypothesis]: + """Recombine hypotheses with equivalent output sequence. + + Args: + hypotheses (list): list of hypotheses + + Returns: + final (list): list of recombined hypotheses + """ + final = [] + + for hyp in hypotheses: + seq_final = [f.y_sequence for f in final if f.y_sequence] + + if hyp.y_sequence in seq_final: + seq_pos = seq_final.index(hyp.y_sequence) + + final[seq_pos].score = np.logaddexp(final[seq_pos].score, hyp.score) + else: + final.append(hyp) + + return hypotheses + + def convtt_beam_search(self, encoder_outputs: torch.Tensor, encoded_lengths: torch.Tensor) -> List[Hypothesis]: + assert self.decoder.prepend_sos_label + start_token_id = self.decoder.sos_idx + + init_pred_net_state = self.decoder.initialize_state(1, encoder_outputs.dtype, encoder_outputs.device) + + kept_hyps = [Hyp(score=1.0, labels=tuple(), + pred_net_state=init_pred_net_state, + last_label=start_token_id)] + blank_label_id = self.decoder.blank_idx + assert start_token_id != blank_label_id + + pred_net_cache = {} + pred_net_cache_steps = 0 + pred_net_cache_missing_steps = 0 + + for t in range(int(encoded_lengths)): + blank_hyps = {} + exp_i = 0 + exp_hyps = kept_hyps + + while exp_i < self.beam_max_exp_step and len(exp_hyps) > 0: + new_hyps = [] + # expand hyps + assert len(exp_hyps) <= self.beam_size + for hyp_i in exp_hyps: + if hyp_i.last_label == blank_label_id: + assert hyp_i.pred_net_output is not None + pred_net_output = hyp_i.pred_net_output + pred_net_state = hyp_i.pred_net_state + else: + assert hyp_i.last_label == start_token_id or hyp_i.last_label == hyp_i.labels[-1] + if hyp_i.labels in pred_net_cache: + pred_net_cache_steps += 1 + pred_net_output, pred_net_state = pred_net_cache[hyp_i.labels] + else: + pred_net_cache_missing_steps += 1 + last_label_t = torch.tensor([[hyp_i.last_label]], dtype=torch.long, device=encoder_outputs.device) + pred_net_output, pred_net_state = self.decoder.predict(last_label_t, + hyp_i.pred_net_state, + add_sos=False) + pred_net_cache[hyp_i.labels] = (pred_net_output, pred_net_state) + # [B, T, U, V + 1] + logits = self.joint.joint(encoder_outputs[:, t:t+1, :], pred_net_output, apply_softmax=False) + if self.beam_temperature != 1.0: + logits = logits / self.beam_temperature + new_hyp_logp = torch.log_softmax(logits, dim=-1) + new_hyp_logp = new_hyp_logp.cpu().numpy() + # new_hyp_logp = new_hyp_logp.astype(np.float64) + new_hyp_logp += hyp_i.score + top_new_hyp_idx = np.argsort(new_hyp_logp) + assert top_new_hyp_idx.shape[:3] == (1, 1, 1) + top_new_hyp_idx = top_new_hyp_idx[0, 0, 0, -(self.beam_size + 1):] + for label_i in top_new_hyp_idx: + if label_i == blank_label_id: + continue + new_hyp = Hyp(score=new_hyp_logp[0, 0, 0, label_i], + labels=hyp_i.labels + (label_i,), + pred_net_state=pred_net_state, + pred_net_output=pred_net_output, + last_label=label_i) + new_hyps.append(new_hyp) + new_blank_hyp = Hyp(score=new_hyp_logp[0, 0, 0, blank_label_id], + labels=hyp_i.labels, + pred_net_state=pred_net_state, + pred_net_output=pred_net_output, + last_label=blank_label_id) + if self.beam_prune_exp: + new_hyps.append(new_blank_hyp) + if new_blank_hyp.labels in blank_hyps: + if self.beam_combine_path: + existing_blank_hyp = blank_hyps[new_blank_hyp.labels] + existing_blank_hyp.score = np.logaddexp(existing_blank_hyp.score, new_blank_hyp.score) + else: + if new_blank_hyp.score > blank_hyps[new_blank_hyp.labels].score: + blank_hyps[new_blank_hyp.labels] = new_blank_hyp + else: + blank_hyps[new_blank_hyp.labels] = new_blank_hyp + + if self.beam_prune_exp_full and exp_i > 0: + exp_beam_size = len(exp_hyps) + else: + exp_beam_size = self.beam_size + + # one expansion step for all beam finished, select top K hyp for further search + exp_hyps = get_top_k_hyps(new_hyps, exp_beam_size, ln_alpha=self.beam_stepwise_ln_alpha, + word_reward_ratio=self.beam_word_reward_ratio) + if self.beam_prune_exp: + # orig_exp_hyp_num = len(exp_hyps) + exp_hyps = [exp_hyp_i for exp_hyp_i in exp_hyps if exp_hyp_i.last_label != blank_label_id] + # print('pruned exp_hyps from {} to {}'.format(orig_exp_hyp_num, len(exp_hyps))) + exp_i += 1 + # expand finished + blank_hyps = list(blank_hyps.values()) + kept_hyps = get_top_k_hyps(blank_hyps, self.beam_size, always_sort=True, + ln_alpha=self.beam_stepwise_ln_alpha, + word_reward_ratio=self.beam_word_reward_ratio) + + if self.score_norm: + if self.beam_stepwise_ln_alpha > 0: + scores_len_norm = [hyp_i.length_norm_score(self.beam_stepwise_ln_alpha) for hyp_i in kept_hyps] + elif self.beam_word_reward_ratio > 0: + scores_len_norm = [hyp_i.word_reward_score(self.beam_word_reward_ratio) for hyp_i in kept_hyps] + else: + len_normed_best_hyp, scores_len_norm = get_length_normed_best(kept_hyps) + else: + scores_len_norm = [0] * len(kept_hyps) + + ret_hyps = [] + for i in range(len(kept_hyps)): + ret_hyps.append(Hypothesis(score=scores_len_norm[i] if self.score_norm else kept_hyps[i].score, + y_sequence=[int(label_i) for label_i in kept_hyps[i].labels])) + return sorted(ret_hyps, key=lambda x: x.score, reverse=True) + + +class Hyp: + __slots__ = ('score', 'labels', 'pred_net_state', 'pred_net_output', 'last_label') + + def __init__(self, score, labels=None, pred_net_state=None, pred_net_output=None, last_label=None): + self.score = score + self.labels = labels + self.pred_net_state = pred_net_state + self.pred_net_output = pred_net_output + self.last_label = last_label + + def length_norm_score(self, alpha): + return length_norm(self.score, len(self.labels), alpha) + + def word_reward_score(self, reward_ratio): + label_len = len(self.labels) + # if not self.blank: + # label_len -= 1 + return self.score + label_len * reward_ratio + + +def length_norm(log_prob, label_len, alpha): + len_norm = (label_len + 5.) / 6. + if alpha != 1: + len_norm = len_norm ** alpha + return log_prob / len_norm + + +def get_top_k_hyps(hyps, k, always_sort=False, ln_alpha=0., word_reward_ratio=0.): + if len(hyps) <= k and not always_sort: + return hyps + else: + if ln_alpha > 0: + hyp_scores = np.array([hyp_i.length_norm_score(ln_alpha) for hyp_i in hyps]) + elif word_reward_ratio > 0: + hyp_scores = np.array([hyp_i.word_reward_score(word_reward_ratio) for hyp_i in hyps]) + else: + hyp_scores = np.array([hyp_i.score for hyp_i in hyps]) + top_k_idx = np.argsort(hyp_scores)[-k:] + return [hyps[i] for i in top_k_idx] + + +def get_length_normed_best(hyps): + best_hyp = None + best_hyp_score = None + score_len_normed = [] + for hyp_i in hyps: + length_normed_score = hyp_i.score / (len(hyp_i.labels) + 1e-16) + score_len_normed.append(length_normed_score) + if best_hyp_score is None or length_normed_score > best_hyp_score: + best_hyp_score = length_normed_score + best_hyp = hyp_i + assert best_hyp is not None + return best_hyp, score_len_normed diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/transformert_greedy_decoding.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/transformert_greedy_decoding.py new file mode 100644 index 0000000000000000000000000000000000000000..41b30d4c7e87907cbfe6c5e240ef8f68d7f87407 --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/transformert_greedy_decoding.py @@ -0,0 +1,540 @@ +from typing import Optional, Union + +import torch + +from nemo.collections.asr.modules import rnnt_abstract, TransformerTDecoder +from nemo.collections.asr.parts import rnnt_utils +from nemo.collections.common.parts.rnn import label_collate +from nemo.core.classes import Typing, typecheck +from nemo.core.neural_types import AcousticEncodedRepresentation, HypothesisType, LengthsType, NeuralType + + +class _GreedyTransformerTInfer(Typing): + """A greedy transducer decoder. + + Provides a common abstraction for sample level and batch level greedy decoding. + + Args: + decoder_model: rnnt_utils.AbstractRNNTDecoder implementation. + joint_model: rnnt_utils.AbstractRNNTJoint implementation. + blank_index: int index of the blank token. Can be 0 or len(vocabulary). + max_symbols_per_step: Optional int. The maximum number of symbols that can be added + to a sequence in a single time step; if set to None then there is + no limit. + """ + + @property + def input_types(self): + """Returns definitions of module input ports. + """ + return { + "encoder_output": NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation()), + "encoded_lengths": NeuralType(tuple('B'), LengthsType()), + } + + @property + def output_types(self): + """Returns definitions of module output ports. + """ + return {"predictions": NeuralType(elements_type=HypothesisType())} + + def __init__( + self, + decoder_model: TransformerTDecoder, + joint_model: rnnt_abstract.AbstractRNNTJoint, + blank_index: int, + max_symbols_per_step: Optional[int] = None, + ): + super().__init__() + self.decoder = decoder_model + self.joint = joint_model + + self._blank_index = blank_index + self._SOS = blank_index # Start of single index + self.max_symbols = max_symbols_per_step + + def __call__(self, *args, **kwargs): + return self.forward(*args, **kwargs) + + @torch.no_grad() + def _pred_step( + self, + label: Union[torch.Tensor, int], + hidden: Optional[torch.Tensor], + add_sos: bool = False, + batch_size: Optional[int] = None, + ) -> (torch.Tensor, torch.Tensor): + """ + Common prediction step based on the AbstractRNNTDecoder implementation. + + Args: + label: (int/torch.Tensor): Label or "Start-of-Signal" token. + hidden: (Optional torch.Tensor): RNN State vector + add_sos (bool): Whether to add a zero vector at the begging as "start of sentence" token. + batch_size: Batch size of the output tensor. + + Returns: + g: (B, U, H) if add_sos is false, else (B, U + 1, H) + hid: (h, c) where h is the final sequence hidden state and c is + the final cell state: + h (tensor), shape (L, B, H) + c (tensor), shape (L, B, H) + """ + if isinstance(label, torch.Tensor): + # label: [batch, 1] + if label.dtype != torch.long: + label = label.long() + + else: + # Label is an integer + if label == self._SOS: + assert not self.decoder.prepend_sos_label + return self.decoder.predict(None, hidden, add_sos=add_sos, batch_size=batch_size) + + label = label_collate([[label]]) + + # output: [B, 1, K] + return self.decoder.predict(label, hidden, add_sos=add_sos, batch_size=batch_size) + + def _joint_step(self, enc, pred, log_normalize: Optional[bool] = None): + """ + Common joint step based on AbstractRNNTJoint implementation. + + Args: + enc: Output of the Encoder model. A torch.Tensor of shape [B, 1, H1] + pred: Output of the Decoder model. A torch.Tensor of shape [B, 1, H2] + log_normalize: Whether to log normalize or not. None will log normalize only for CPU. + + Returns: + logits of shape (B, T=1, U=1, V + 1) + """ + with torch.no_grad(): + logits = self.joint.joint(enc, pred) + + if log_normalize is None: + if not logits.is_cuda: # Use log softmax only if on CPU + logits = logits.log_softmax(dim=len(logits.shape) - 1) + else: + if log_normalize: + logits = logits.log_softmax(dim=len(logits.shape) - 1) + + return logits + + +class GreedyTransformerTInfer(_GreedyTransformerTInfer): + """A greedy transducer decoder. + + Sequence level greedy decoding, performed auto-repressively. + + Args: + decoder_model: rnnt_utils.AbstractRNNTDecoder implementation. + joint_model: rnnt_utils.AbstractRNNTJoint implementation. + blank_index: int index of the blank token. Can be 0 or len(vocabulary). + max_symbols_per_step: Optional int. The maximum number of symbols that can be added + to a sequence in a single time step; if set to None then there is + no limit. + """ + + def __init__( + self, + decoder_model: rnnt_abstract.AbstractRNNTDecoder, + joint_model: rnnt_abstract.AbstractRNNTJoint, + blank_index: int, + max_symbols_per_step: Optional[int] = None, + ): + super().__init__( + decoder_model=decoder_model, + joint_model=joint_model, + blank_index=blank_index, + max_symbols_per_step=max_symbols_per_step, + ) + + @typecheck() + def forward(self, encoder_output: torch.Tensor, encoded_lengths: torch.Tensor): + """Returns a list of hypotheses given an input batch of the encoder hidden embedding. + Output token is generated auto-repressively. + + Args: + encoder_output: A tensor of size (batch, features, timesteps). + encoded_lengths: list of int representing the length of each sequence + output sequence. + + Returns: + packed list containing batch number of sentences (Hypotheses). + """ + # Preserve decoder and joint training state + decoder_training_state = self.decoder.training + joint_training_state = self.joint.training + self.decoder.eval() + self.joint.eval() + + with torch.no_grad(): + # Apply optional preprocessing + encoder_output = encoder_output.transpose(1, 2) # (B, T, D) + + hypotheses = [] + # Process each sequence independently + for batch_idx in range(encoder_output.size(0)): + inseq = encoder_output[batch_idx, :, :].unsqueeze(1) # [T, 1, D] + logitlen = encoded_lengths[batch_idx] + sentence = self._greedy_decode(inseq, logitlen) + hypotheses.append(sentence) + + # Pack results into Hypotheses + packed_result = [ + rnnt_utils.Hypothesis(y_sequence=torch.tensor(sent, dtype=torch.long), score=-1.0) + for sent in hypotheses + ] + + self.decoder.train(decoder_training_state) + self.joint.train(joint_training_state) + + return (packed_result,) + + @torch.no_grad() + def _greedy_decode(self, x: torch.Tensor, out_len: torch.Tensor): + # x: [T, 1, D] + # out_len: [seq_len] + + hidden = self.decoder.initialize_state(x.size(1), x.dtype, x.device) + if not self.decoder.prepend_sos_label: + # Initialize blank state and empty label set + label = [] + else: + label = [self.decoder.sos_idx] + + + # For timestep t in X_t + for time_idx in range(out_len): + # Extract encoder embedding at timestep t + # f = x[time_idx, :, :].unsqueeze(0) # [1, 1, D] + f = x.narrow(dim=0, start=time_idx, length=1) + + # Setup exit flags and counter + not_blank = True + symbols_added = 0 + + # While blank is not predicted, or we dont run out of max symbols per timestep + while not_blank and (self.max_symbols is None or symbols_added < self.max_symbols): + # In the first timestep, we initialize the network with TransformerT Blank + # In later timesteps, we provide previous predicted label as input. + last_label = self._SOS if label == [] else label[-1] + + # Perform prediction network and joint network steps. + g, hidden_prime = self._pred_step(last_label, hidden) + logp = self._joint_step(f, g, log_normalize=None)[0, 0, 0, :] + + del g + + # torch.max(0) op doesnt exist for FP 16. + if logp.dtype != torch.float32: + logp = logp.float() + + # get index k, of max prob + v, k = logp.max(0) + k = k.item() # K is the label at timestep t_s in inner loop, s >= 0. + + del logp + + # If blank token is predicted, exit inner loop, move onto next timestep t + if k == self._blank_index: + not_blank = False + else: + # Append token to label set, update RNN state. + label.append(k) + hidden = hidden_prime + + # Increment token counter. + symbols_added += 1 + + return label + + +class GreedyBatchedTransformerTInfer(_GreedyTransformerTInfer): + """A batch level greedy transducer decoder. + + Batch level greedy decoding, performed auto-repressively. + + Args: + decoder_model: rnnt_utils.AbstractRNNTDecoder implementation. + joint_model: rnnt_utils.AbstractRNNTJoint implementation. + blank_index: int index of the blank token. Can be 0 or len(vocabulary). + max_symbols_per_step: Optional int. The maximum number of symbols that can be added + to a sequence in a single time step; if set to None then there is + no limit. + """ + + def __init__( + self, + decoder_model: rnnt_abstract.AbstractRNNTDecoder, + joint_model: rnnt_abstract.AbstractRNNTJoint, + blank_index: int, + max_symbols_per_step: Optional[int] = None, + ): + super().__init__( + decoder_model=decoder_model, + joint_model=joint_model, + blank_index=blank_index, + max_symbols_per_step=max_symbols_per_step, + ) + + # Depending on availability of `blank_as_pad` support + # switch between more efficient batch decoding technique + if self.decoder.blank_as_pad: + self._greedy_decode = self._greedy_decode_blank_as_pad + else: + self._greedy_decode = self._greedy_decode_masked + + @typecheck() + def forward(self, encoder_output: torch.Tensor, encoded_lengths: torch.Tensor): + """Returns a list of hypotheses given an input batch of the encoder hidden embedding. + Output token is generated auto-repressively. + + Args: + encoder_output: A tensor of size (batch, features, timesteps). + encoded_lengths: list of int representing the length of each sequence + output sequence. + + Returns: + packed list containing batch number of sentences (Hypotheses). + """ + # Preserve decoder and joint training state + decoder_training_state = self.decoder.training + joint_training_state = self.joint.training + + with torch.no_grad(): + # Apply optional preprocessing + encoder_output = encoder_output.transpose(1, 2) # (B, T, D) + logitlen = encoded_lengths + + self.decoder.eval() + self.joint.eval() + + with self.decoder.as_frozen(), self.joint.as_frozen(): + inseq = encoder_output # [B, T, D] + hypotheses = self._greedy_decode(inseq, logitlen, device=inseq.device) + + # Pack the hypotheses results + packed_result = [ + rnnt_utils.Hypothesis(y_sequence=torch.tensor(sent, dtype=torch.long), score=-1.0) + for sent in hypotheses + ] + + del hypotheses + + self.decoder.train(decoder_training_state) + self.joint.train(joint_training_state) + + return (packed_result,) + + def _greedy_decode_blank_as_pad(self, x: torch.Tensor, out_len: torch.Tensor, device: torch.device): + with torch.no_grad(): + # x: [B, T, D] + # out_len: [B] + # device: torch.device + + # Initialize state + hidden = None + batchsize = x.shape[0] + + # Output string buffer + label = [[] for _ in range(batchsize)] + + # Last Label buffer + Last Label without blank buffer + # batch level equivalent of the last_label + last_label = torch.full([batchsize, 1], fill_value=self._blank_index, dtype=torch.long, device=device) + + # Mask buffers + blank_mask = torch.full([batchsize], fill_value=0, dtype=torch.bool, device=device) + + # Get max sequence length + max_out_len = out_len.max() + + for time_idx in range(max_out_len): + f = x.narrow(dim=1, start=time_idx, length=1) # [B, 1, D] + + # Prepare t timestamp batch variables + not_blank = True + symbols_added = 0 + + # Reset blank mask + blank_mask.mul_(False) + + # Update blank mask with time mask + # Batch: [B, T, D], but Bi may have seq len < max(seq_lens_in_batch) + # Forcibly mask with "blank" tokens, for all sample where current time step T > seq_len + blank_mask = time_idx >= out_len + + # Start inner loop + while not_blank and (self.max_symbols is None or symbols_added < self.max_symbols): + + # Batch prediction and joint network steps + # If very first prediction step, submit SOS tag (blank) to pred_step. + # This feeds a zero tensor as input to AbstractRNNTDecoder to prime the state + if time_idx == 0 and symbols_added == 0: + g, hidden_prime = self._pred_step(self._SOS, hidden, batch_size=batchsize) + else: + # Perform batch step prediction of decoder, getting new states and scores ("g") + g, hidden_prime = self._pred_step(last_label, hidden, batch_size=batchsize) + + # Batched joint step - Output = [B, V + 1] + logp = self._joint_step(f, g, log_normalize=None)[:, 0, 0, :] + + if logp.dtype != torch.float32: + logp = logp.float() + + # Get index k, of max prob for batch + v, k = logp.max(1) + del v, g, logp + + # Update blank mask with current predicted blanks + # This is accumulating blanks over all time steps T and all target steps min(max_symbols, U) + k_is_blank = k == self._blank_index + blank_mask |= k_is_blank + + del k_is_blank + + # If all samples predict / have predicted prior blanks, exit loop early + # This is equivalent to if single sample predicted k + if blank_mask.all(): + not_blank = False + else: + # Collect batch indices where blanks occurred now/past + blank_indices = [] + if hidden is not None: + blank_indices = (blank_mask == 1).nonzero(as_tuple=False) + + # Recover prior state for all samples which predicted blank now/past + if hidden is not None: + # LSTM has 2 states + for state_id in range(len(hidden)): + hidden_prime[state_id][:, blank_indices, :] = hidden[state_id][:, blank_indices, :] + + # Recover prior predicted label for all samples which predicted blank now/past + k[blank_indices] = last_label[blank_indices, 0] + + # Update new label and hidden state for next iteration + last_label = k.clone().view(-1, 1) + hidden = hidden_prime + + # Update predicted labels, accounting for time mask + # If blank was predicted even once, now or in the past, + # Force the current predicted label to also be blank + # This ensures that blanks propogate across all timesteps + # once they have occured (normally stopping condition of sample level loop). + for kidx, ki in enumerate(k): + if blank_mask[kidx] == 0: + label[kidx].append(ki) + + symbols_added += 1 + + return label + + @torch.no_grad() + def _greedy_decode_masked(self, x: torch.Tensor, out_len: torch.Tensor, device: torch.device): + # x: [B, T, D] + # out_len: [B] + # device: torch.device + + # Initialize state + hidden = None + batchsize = x.shape[0] + + # Output string buffer + label = [[] for _ in range(batchsize)] + + # Last Label buffer + Last Label without blank buffer + # batch level equivalent of the last_label + last_label = torch.full([batchsize, 1], fill_value=self._blank_index, dtype=torch.long, device=device) + last_label_without_blank = last_label.clone() + + # Mask buffers + blank_mask = torch.full([batchsize], fill_value=0, dtype=torch.bool, device=device) + + # Get max sequence length + max_out_len = out_len.max() + for time_idx in range(max_out_len): + f = x.narrow(dim=1, start=time_idx, length=1) # [B, 1, D] + + # Prepare t timestamp batch variables + not_blank = True + symbols_added = 0 + + # Reset blank mask + blank_mask.mul_(False) + + # Update blank mask with time mask + # Batch: [B, T, D], but Bi may have seq len < max(seq_lens_in_batch) + # Forcibly mask with "blank" tokens, for all sample where current time step T > seq_len + blank_mask = time_idx >= out_len + + # Start inner loop + while not_blank and (self.max_symbols is None or symbols_added < self.max_symbols): + # Batch prediction and joint network steps + # If very first prediction step, submit SOS tag (blank) to pred_step. + # This feeds a zero tensor as input to AbstractRNNTDecoder to prime the state + if time_idx == 0 and symbols_added == 0: + g, hidden_prime = self._pred_step(self._SOS, hidden, batch_size=batchsize) + else: + # Set a dummy label for the blank value + # This value will be overwritten by "blank" again the last label update below + # This is done as vocabulary of prediction network does not contain "blank" token of TransformerT + last_label_without_blank_mask = last_label == self._blank_index + last_label_without_blank[last_label_without_blank_mask] = 0 # temp change of label + last_label_without_blank[~last_label_without_blank_mask] = last_label[ + ~last_label_without_blank_mask + ] + + # Perform batch step prediction of decoder, getting new states and scores ("g") + g, hidden_prime = self._pred_step(last_label_without_blank, hidden, batch_size=batchsize) + + # Batched joint step - Output = [B, V + 1] + logp = self._joint_step(f, g, log_normalize=None)[:, 0, 0, :] + + if logp.dtype != torch.float32: + logp = logp.float() + + # Get index k, of max prob for batch + v, k = logp.max(1) + del v, g, logp + + # Update blank mask with current predicted blanks + # This is accumulating blanks over all time steps T and all target steps min(max_symbols, U) + k_is_blank = k == self._blank_index + blank_mask.bitwise_or_(k_is_blank) + + # If all samples predict / have predicted prior blanks, exit loop early + # This is equivalent to if single sample predicted k + if blank_mask.all(): + not_blank = False + else: + # Collect batch indices where blanks occurred now/past + blank_indices = [] + if hidden is not None: + blank_indices = (blank_mask == 1).nonzero(as_tuple=False) + + # Recover prior state for all samples which predicted blank now/past + if hidden is not None: + # LSTM has 2 states + for state_id in range(len(hidden)): + hidden_prime[state_id][:, blank_indices, :] = hidden[state_id][:, blank_indices, :] + + # Recover prior predicted label for all samples which predicted blank now/past + k[blank_indices] = last_label[blank_indices, 0] + + # Update new label and hidden state for next iteration + last_label = k.view(-1, 1) + hidden = hidden_prime + + # Update predicted labels, accounting for time mask + # If blank was predicted even once, now or in the past, + # Force the current predicted label to also be blank + # This ensures that blanks propogate across all timesteps + # once they have occured (normally stopping condition of sample level loop). + for kidx, ki in enumerate(k): + if blank_mask[kidx] == 0: + label[kidx].append(ki) + + symbols_added += 1 + + return label diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/wav2vec.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/wav2vec.py new file mode 100644 index 0000000000000000000000000000000000000000..708398da76ed4c3e9d1c060002cfa1058a661f8d --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/wav2vec.py @@ -0,0 +1,458 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math +from typing import List, Tuple + +import numpy as np +import torch +from torch import nn +import torch.nn.functional as F + +from nemo.collections.asr.models.wav2vec.wav2vec_config import Wav2VecConvExtractorMode, Wav2VecTransformerConfig +from nemo.collections.asr.parts.layer_norm import LayerNorm +from nemo.collections.asr.parts.multihead_attention import MultiheadAttention + + +class TransposeLast(torch.nn.Module): + """ + Transposes last dimension. Useful for adding to a sequential block. + """ + + def forward(self, x): + return x.transpose(-2, -1) + + +class SamePad(torch.nn.Module): + def __init__(self, kernel_size): + super().__init__() + self.remove = kernel_size % 2 == 0 + + def forward(self, x): + if self.remove: + x = x[:, :, :-1] + return x + + +class ConvFeatureEncoder(nn.Module): + """ + Converts input raw audio into features for downstream transformer model. + Uses 1D convolutional blocks with GeLU activation. + """ + + def __init__( + self, + conv_layers: List[Tuple[int, int, int]], + mode: Wav2VecConvExtractorMode = Wav2VecConvExtractorMode.default, + conv_bias: bool = False, + ): + super().__init__() + + def block( + n_in, n_out, k, stride, is_layer_norm=False, is_group_norm=False, conv_bias=False, + ): + def make_conv(): + conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias) + nn.init.kaiming_normal_(conv.weight) + return conv + + assert (is_layer_norm and is_group_norm) is False, "layer norm and group norm are exclusive" + + if is_layer_norm: + return nn.Sequential( + make_conv(), + nn.Sequential(TransposeLast(), nn.LayerNorm(dim, elementwise_affine=True), TransposeLast()), + nn.GELU(), + ) + elif is_group_norm: + return nn.Sequential(make_conv(), nn.GroupNorm(dim, dim, affine=True), nn.GELU(),) + else: + return nn.Sequential(make_conv(), nn.GELU()) + + in_d = 1 + self.conv_layers = nn.ModuleList() + for i, cl in enumerate(conv_layers): + assert len(cl) == 3, "invalid conv definition: " + str(cl) + (dim, k, stride) = cl + + self.conv_layers.append( + block( + in_d, + dim, + k, + stride, + is_layer_norm=mode is Wav2VecConvExtractorMode.layer_norm, + is_group_norm=mode is Wav2VecConvExtractorMode.default and i == 0, + conv_bias=conv_bias, + ) + ) + in_d = dim + + def forward(self, x): + # BxT -> BxCxT + x = x.unsqueeze(1) + for conv in self.conv_layers: + x = conv(x) + return x + + def get_subsampled_lens(self, lens): + for m in self.conv_layers: + conv = m[0] + lens = (lens + 2 * conv.padding[0] - conv.dilation[0] * (conv.kernel_size[0] - 1) - 1) // conv.stride[0] + 1 + return lens + + +class TransformerEncoder(nn.Module): + def __init__(self, args): + super().__init__() + + conv_cfg = args.conv + + self.dropout = args.dropout + self.embedding_dim = args.encoder.embedding_dim + + self.pos_conv = nn.Conv1d( + self.embedding_dim, + self.embedding_dim, + kernel_size=conv_cfg.conv_pos, + padding=conv_cfg.conv_pos // 2, + groups=conv_cfg.conv_pos_groups, + ) + dropout = 0 + std = math.sqrt((4 * (1.0 - dropout)) / (conv_cfg.conv_pos * self.embedding_dim)) + nn.init.normal_(self.pos_conv.weight, mean=0, std=std) + nn.init.constant_(self.pos_conv.bias, 0) + + self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2) + self.pos_conv = nn.Sequential(self.pos_conv, SamePad(conv_cfg.conv_pos), nn.GELU()) + self.pos_conv_layer_drop = conv_cfg.layer_drop + + encoder_cfg = args.encoder + self.layers = nn.ModuleList( + [ + TransformerSentenceEncoderLayer( + embedding_dim=self.embedding_dim, + ffn_embedding_dim=encoder_cfg.ffn_embedding_dim, + num_attention_heads=encoder_cfg.num_attention_heads, + dropout=self.dropout, + attention_dropout=encoder_cfg.attention_dropout, + activation_dropout=encoder_cfg.activation_dropout, + activation_fn=encoder_cfg.activation_fn.value, + layer_norm_first=encoder_cfg.layer_norm_first, + ) + for _ in range(encoder_cfg.encoder_layers) + ] + ) + + self.layer_norm_first = encoder_cfg.layer_norm_first + self.layer_norm = LayerNorm(self.embedding_dim) + self.layerdrop = encoder_cfg.encoder_layerdrop + + self.apply(init_bert_params) + + def forward(self, x, padding_mask=None): + x = self.extract_features(x, padding_mask) + + if self.layer_norm_first: + x = self.layer_norm(x) + + return x + + def extract_features(self, x, padding_mask=None): + + if padding_mask is not None: + x = index_put(x, padding_mask, 0) + + if self.pos_conv_layer_drop > 0 and self.training and np.random.random() < self.pos_conv_layer_drop: + pass + else: + x_conv = self.pos_conv(x.transpose(1, 2)) + x_conv = x_conv.transpose(1, 2) + x = x + x_conv + + if not self.layer_norm_first: + x = self.layer_norm(x) + + x = F.dropout(x, p=self.dropout, training=self.training) + + # B x T x C -> T x B x C + x = x.transpose(0, 1) + + # layer_results = [] + for i, layer in enumerate(self.layers): + dropout_probability = np.random.random() + if not self.training or (dropout_probability >= self.layerdrop): + x, z = layer(x, self_attn_padding_mask=padding_mask, need_weights=False) + # layer_results.append(x) + + # T x B x C -> B x T x C + x = x.transpose(0, 1) + + return x + + +class TransformerSentenceEncoderLayer(nn.Module): + """ + Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained + models. + """ + + def __init__( + self, + embedding_dim: float = 768, + ffn_embedding_dim: float = 3072, + num_attention_heads: float = 8, + dropout: float = 0.1, + attention_dropout: float = 0.1, + activation_dropout: float = 0.1, + activation_fn: str = "relu", + layer_norm_first: bool = False, + ) -> None: + + super().__init__() + # Initialize parameters + self.embedding_dim = embedding_dim + self.dropout = dropout + self.activation_dropout = activation_dropout + + # Initialize blocks + self.activation_fn = get_activation_fn(activation_fn) + self.self_attn = MultiheadAttention( + self.embedding_dim, + num_attention_heads, + dropout=attention_dropout, + self_attention=True, + ) + + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(self.activation_dropout) + self.dropout3 = nn.Dropout(dropout) + + self.layer_norm_first = layer_norm_first + + # layer norm associated with the self attention layer + self.self_attn_layer_norm = LayerNorm(self.embedding_dim) + self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim) + self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim) + + # layer norm associated with the position wise feed-forward NN + self.final_layer_norm = LayerNorm(self.embedding_dim) + + def forward( + self, + x: torch.Tensor, + self_attn_mask: torch.Tensor = None, + self_attn_padding_mask: torch.Tensor = None, + need_weights: bool = False, + att_args=None, + ): + """ + LayerNorm is applied either before or after the self-attention/ffn + modules similar to the original Transformer imlementation. + """ + residual = x + + if self.layer_norm_first: + x = self.self_attn_layer_norm(x) + x, attn = self.self_attn( + query=x, + key=x, + value=x, + key_padding_mask=self_attn_padding_mask, + need_weights=False, + attn_mask=self_attn_mask, + ) + x = self.dropout1(x) + x = residual + x + + residual = x + x = self.final_layer_norm(x) + x = self.activation_fn(self.fc1(x)) + x = self.dropout2(x) + x = self.fc2(x) + x = self.dropout3(x) + x = residual + x + else: + x, attn = self.self_attn( + query=x, + key=x, + value=x, + key_padding_mask=self_attn_padding_mask, + need_weights=need_weights, + ) + + x = self.dropout1(x) + x = residual + x + + x = self.self_attn_layer_norm(x) + + residual = x + x = self.activation_fn(self.fc1(x)) + x = self.dropout2(x) + x = self.fc2(x) + x = self.dropout3(x) + x = residual + x + x = self.final_layer_norm(x) + + return x, attn + + +def index_put(tensor, indices, value): + tensor[indices] = value + return tensor + + +def get_activation_fn(activation: str): + """ Returns the activation function corresponding to `activation` """ + if activation == "relu": + return F.relu + elif activation == "gelu": + return F.gelu + elif activation == "tanh": + return torch.tanh + elif activation == "linear": + return lambda x: x + else: + raise RuntimeError("--activation-fn {} not supported".format(activation)) + + +class Wav2VecTransformerEncoder(nn.Module): + def __init__(self, cfg: Wav2VecTransformerConfig): + super().__init__() + + conv_cfg = cfg.conv + + self.dropout = cfg.dropout + self.embedding_dim = cfg.encoder.embedding_dim + self.layer_norm_first = cfg.encoder.layer_norm_first + assert not self.layer_norm_first, 'nn.TransformerEncoderLayer do not support layer_norm_first' + + # positional convolutional embeddings + self.pos_conv = nn.Conv1d( + self.embedding_dim, + self.embedding_dim, + kernel_size=conv_cfg.conv_pos, + padding=conv_cfg.conv_pos // 2, + groups=conv_cfg.conv_pos_groups, + ) + + self.feature_dropout = nn.Dropout(self.dropout) + + dropout = 0 + std = math.sqrt((4 * (1.0 - dropout)) / (conv_cfg.conv_pos * self.embedding_dim)) + nn.init.normal_(self.pos_conv.weight, mean=0, std=std) + nn.init.constant_(self.pos_conv.bias, 0) + + self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2) + self.pos_conv = nn.Sequential(self.pos_conv, SamePad(conv_cfg.conv_pos), nn.GELU()) + + encoder_cfg = cfg.encoder + self.transformer_encoder = nn.TransformerEncoder( + encoder_layer=nn.TransformerEncoderLayer( + d_model=self.embedding_dim, + nhead=encoder_cfg.num_attention_heads, + dim_feedforward=encoder_cfg.ffn_embedding_dim, + dropout=self.dropout, + activation=encoder_cfg.activation_fn.value, + ), + num_layers=encoder_cfg.encoder_layers, + ) + self.layer_norm = nn.LayerNorm(self.embedding_dim) + self.apply(init_bert_params) + + def forward(self, x, padding_mask=None): + x = self.extract_features(x, padding_mask) + + if self.layer_norm_first: + x = self.layer_norm(x) + + return x + + def extract_features(self, x, padding_mask=None): + + if padding_mask is not None: + x[padding_mask] = 0 + + x_conv = self.pos_conv(x.transpose(1, 2)) + x_conv = x_conv.transpose(1, 2) + x += x_conv + + if not self.layer_norm_first: + x = self.layer_norm(x) + + x = self.feature_dropout(x) + + # B x T x C -> T x B x C + x = x.transpose(0, 1) + + x = self.transformer_encoder(x, src_key_padding_mask=padding_mask) + + # T x B x C -> B x T x C + x = x.transpose(0, 1) + + return x + + +class GradMultiply(torch.autograd.Function): + @staticmethod + def forward(ctx, x, scale): + ctx.scale = scale + res = x.new(x) + return res + + @staticmethod + def backward(ctx, grad): + return grad * ctx.scale, None + + +def init_bert_params(module): + """ + Initialize the weights specific to the BERT Model. + This overrides the default initializations depending on the specified arguments. + 1. If normal_init_linear_weights is set then weights of linear + layer will be initialized using the normal distribution and + bias will be set to the specified value. + 2. If normal_init_embed_weights is set then weights of embedding + layer will be initialized using the normal distribution. + 3. If normal_init_proj_weights is set then weights of + in_project_weight for MultiHeadAttention initialized using + the normal distribution (to be validated). + """ + + def normal_(data): + # with FSDP, module params will be on CUDA, so we cast them back to CPU + # so that the RNG is consistent with and without FSDP + data.copy_( + data.cpu().normal_(mean=0.0, std=0.02).to(data.device) + ) + + if isinstance(module, nn.Linear): + normal_(module.weight.data) + if module.bias is not None: + module.bias.data.zero_() + if isinstance(module, nn.Embedding): + normal_(module.weight.data) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + if isinstance(module, MultiheadAttention): + normal_(module.q_proj.weight.data) + normal_(module.k_proj.weight.data) + normal_(module.v_proj.weight.data) + if isinstance(module, nn.TransformerEncoderLayer): + normal_(module.self_attn.in_proj_weight.data) diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/__init__.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..52a45b36ffa09b87165240ccc6c176baee2af94e --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/__init__.py @@ -0,0 +1,26 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import nemo.collections.common.callbacks +from nemo.collections.common import losses, parts, tokenizers +from nemo.package_info import __version__ + +# Set collection version equal to NeMo version. +__version = __version__ + +# Authorship. +__author__ = "NVIDIA Corporation" + +# Set collection name. +__description__ = "Common collection" diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/callbacks/__init__.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/callbacks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9ad5c9c85a5f1c5921c115112e313b86a1e123a2 --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/callbacks/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nemo.collections.common.callbacks.callbacks import LogEpochTimeCallback diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/callbacks/callbacks.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/callbacks/callbacks.py new file mode 100644 index 0000000000000000000000000000000000000000..55fa5c50a1c54fa73ec2f8f2ef5b3fbfe96016e4 --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/callbacks/callbacks.py @@ -0,0 +1,32 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import time + +from pytorch_lightning.callbacks.base import Callback +from pytorch_lightning.utilities import rank_zero_only + + +class LogEpochTimeCallback(Callback): + """Simple callback that logs how long each epoch takes, in seconds, to a pytorch lightning log + """ + + @rank_zero_only + def on_epoch_start(self, trainer, pl_module): + self.epoch_start = time.time() + + @rank_zero_only + def on_epoch_end(self, trainer, pl_module): + curr_time = time.time() + duration = curr_time - self.epoch_start + trainer.logger.log_metrics({"epoch_time": duration}, step=trainer.global_step) diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/losses/__init__.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/losses/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ae7d3b0c1a8cc214d04b2b814f24f7d21ebd251d --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/losses/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nemo.collections.common.losses.aggregator import AggregatorLoss +from nemo.collections.common.losses.cross_entropy import CrossEntropyLoss +from nemo.collections.common.losses.mse_loss import MSELoss +from nemo.collections.common.losses.smoothed_cross_entropy import SmoothedCrossEntropyLoss +from nemo.collections.common.losses.spanning_loss import SpanningLoss diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/losses/aggregator.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/losses/aggregator.py new file mode 100644 index 0000000000000000000000000000000000000000..1987ddd22bae30d6455adf87a7cc4313d9cdff1c --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/losses/aggregator.py @@ -0,0 +1,67 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List + +import torch + +from nemo.core.classes import Loss, typecheck +from nemo.core.neural_types import LossType, NeuralType + +__all__ = ['AggregatorLoss'] + + +class AggregatorLoss(Loss): + """ + Sums several losses into one. + + Args: + num_inputs: number of input losses + weights: a list of coefficient for merging losses + """ + + @property + def input_types(self): + """Returns definitions of module input ports. + """ + input_types = {} + for i in range(self._num_losses): + input_types["loss_" + str(i + 1)] = NeuralType(elements_type=LossType()) + + return input_types + + @property + def output_types(self): + """Returns definitions of module output ports. + """ + return {"loss": NeuralType(elements_type=LossType())} + + def __init__(self, num_inputs: int = 2, weights: List[float] = None): + super().__init__() + self._num_losses = num_inputs + if weights is not None and len(weights) != num_inputs: + raise ValueError("Length of weights should be equal to the number of inputs (num_inputs)") + + self._weights = weights + + @typecheck() + def forward(self, **kwargs): + values = [kwargs[x] for x in sorted(kwargs.keys())] + loss = torch.zeros_like(values[0]) + for loss_idx, loss_value in enumerate(values): + if self._weights is not None: + loss = loss.add(loss_value, alpha=self._weights[loss_idx]) + else: + loss = loss.add(loss_value) + return loss diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/losses/cross_entropy.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/losses/cross_entropy.py new file mode 100644 index 0000000000000000000000000000000000000000..b6ed0a6b3d81929e8a544de6140d45435ce7440d --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/losses/cross_entropy.py @@ -0,0 +1,79 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from torch import nn + +from nemo.core.classes import Serialization, Typing, typecheck +from nemo.core.neural_types import LabelsType, LogitsType, LossType, MaskType, NeuralType + +__all__ = ['CrossEntropyLoss'] + + +class CrossEntropyLoss(nn.CrossEntropyLoss, Serialization, Typing): + """ + CrossEntropyLoss + """ + + @property + def input_types(self): + """Returns definitions of module input ports. + """ + return { + "logits": NeuralType(['B'] + ['ANY'] * (self._logits_dim - 1), LogitsType()), + "labels": NeuralType(['B'] + ['ANY'] * (self._logits_dim - 2), LabelsType()), + "loss_mask": NeuralType(['B'] + ['ANY'] * (self._logits_dim - 2), MaskType(), optional=True), + } + + @property + def output_types(self): + """Returns definitions of module output ports. + """ + return {"loss": NeuralType(elements_type=LossType())} + + def __init__(self, logits_ndim=2, weight=None, reduction='mean'): + """ + Args: + logits_ndim (int): number of dimensions (or rank) of the logits tensor + weight (list): list of rescaling weight given to each class + reduction (str): type of the reduction over the batch + """ + if weight is not None and not torch.is_tensor(weight): + weight = torch.FloatTensor(weight) + super().__init__(weight=weight, reduction=reduction) + self._logits_dim = logits_ndim + + @typecheck() + def forward(self, logits, labels, loss_mask=None): + """ + Args: + logits (float): output of the classifier + labels (long): ground truth labels + loss_mask (bool/float/int): tensor to specify the masking + """ + logits_flatten = torch.flatten(logits, start_dim=0, end_dim=-2) + labels_flatten = torch.flatten(labels, start_dim=0, end_dim=-1) + + if loss_mask is not None: + if loss_mask.dtype is not torch.bool: + loss_mask = loss_mask > 0.5 + loss_mask_flatten = torch.flatten(loss_mask, start_dim=0, end_dim=-1) + logits_flatten = logits_flatten[loss_mask_flatten] + labels_flatten = labels_flatten[loss_mask_flatten] + + if len(labels_flatten) == 0: + return super().forward(logits, torch.argmax(logits, dim=-1)) + + loss = super().forward(logits_flatten, labels_flatten) + return loss diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/losses/mse_loss.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/losses/mse_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..802e8ca49204a9bf9d439b114d4f3315d453e00f --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/losses/mse_loss.py @@ -0,0 +1,57 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from torch import Tensor, nn + +from nemo.core.classes import Serialization, Typing, typecheck +from nemo.core.neural_types import LabelsType, LossType, NeuralType, RegressionValuesType + +__all__ = ['MSELoss'] + + +class MSELoss(nn.MSELoss, Serialization, Typing): + """ + MSELoss + """ + + @property + def input_types(self): + """Returns definitions of module input ports. + """ + return { + "preds": NeuralType(tuple('B'), RegressionValuesType()), + "labels": NeuralType(tuple('B'), LabelsType()), + } + + @property + def output_types(self): + """Returns definitions of module output ports. + """ + return {"loss": NeuralType(elements_type=LossType())} + + def __init__(self, reduction: str = 'mean'): + """ + Args: + reduction: type of the reduction over the batch + """ + super().__init__(reduction=reduction) + + @typecheck() + def forward(self, preds: Tensor, labels: Tensor) -> Tensor: + """ + Args: + preds: output of the classifier + labels: ground truth labels + """ + return super().forward(preds, labels) diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/losses/smoothed_cross_entropy.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/losses/smoothed_cross_entropy.py new file mode 100644 index 0000000000000000000000000000000000000000..8e37d29ba45347d705c19833ce8dd2da60acde04 --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/losses/smoothed_cross_entropy.py @@ -0,0 +1,103 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +import torch + +from nemo.core.classes import Loss, typecheck +from nemo.core.neural_types import LabelsType, LogprobsType, LossType, MaskType, NeuralType + +__all__ = ["SmoothedCrossEntropyLoss"] + + +class SmoothedCrossEntropyLoss(Loss): + """ + Calculates Cross-entropy loss with label smoothing for a batch of sequences. + + SmoothedCrossEntropyLoss: + 1) excludes padding tokens from loss calculation + 2) allows to use label smoothing regularization + 3) allows to calculate loss for the desired number of last tokens + + Args: + label_smoothing (float): label smoothing regularization coefficient + predict_last_k (int): parameter which sets the number of last tokens to calculate the loss for, for example + 0: (default) calculate loss on the entire sequence (e.g., NMT) + 1: calculate loss on the last token only (e.g., LM evaluation) + Intermediate values allow to control the trade-off between eval + time (proportional to the number of batches) and eval performance + (proportional to the number of context tokens) + pad_id (int): padding id + eps (float): the small eps number to avoid division buy zero + """ + + @property + def input_types(self): + """Returns definitions of module input ports. + """ + return { + "log_probs": NeuralType(("B", "T", "D"), LogprobsType()), + "labels": NeuralType(("B", "T"), LabelsType()), + "output_mask": NeuralType(("B", "T"), MaskType(), optional=True), + } + + @property + def output_types(self): + """Returns definitions of module output ports. + """ + return {"loss": NeuralType(elements_type=LossType())} + + def __init__( + self, + pad_id: Optional[int] = None, + label_smoothing: Optional[float] = 0.0, + predict_last_k: Optional[int] = 0, + eps: float = 1e-6, + ): + super().__init__() + self._pad_id = pad_id + self._eps = eps + self._predict_last_k = predict_last_k + self._label_smoothing = label_smoothing + + @typecheck() + def forward(self, log_probs, labels, output_mask=None): + """ + Args: + log_probs: float tensor of shape batch_size x seq_len x vocab_size, values should be log probabilities + labels: int tensor of shape batch_size x seq_len + output_mask: binary tensor of shape batch_size x seq_len + eps: epsilon param to avoid divide by zero in loss calculation + """ + if output_mask is None and self._pad_id is None: + raise ValueError("Both output_mask and pad_id are None") + if output_mask is None and self._pad_id is not None: + output_mask = (labels != self._pad_id).to(log_probs.dtype) + + if output_mask.dtype is not log_probs.dtype: + output_mask = output_mask.to(log_probs.dtype) + + batch_size, seq_len, vocab_size = log_probs.size() + smoothing = vocab_size * self._label_smoothing / (vocab_size - 1) + target_log_probs = log_probs.gather(2, labels.unsqueeze(2)).squeeze(2) + + smoothing_log_probs = log_probs.mean(dim=-1) + neg_log_likelihood = (1.0 - smoothing) * target_log_probs + smoothing * smoothing_log_probs + neg_log_likelihood = neg_log_likelihood[:, -self._predict_last_k :] + output_mask = output_mask[:, -self._predict_last_k :] + neg_log_likelihood = -torch.sum(neg_log_likelihood * output_mask) + neg_log_likelihood = neg_log_likelihood / (output_mask.sum() + self._eps) + + return neg_log_likelihood diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/losses/spanning_loss.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/losses/spanning_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..a12dab64afd00f83003e8fb9fa25af3d72a360c7 --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/losses/spanning_loss.py @@ -0,0 +1,79 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from torch import nn + +from nemo.core.classes import Loss, typecheck +from nemo.core.neural_types import ChannelType, LogitsType, LossType, NeuralType + +__all__ = ['SpanningLoss'] + + +class SpanningLoss(Loss): + """ + implements start and end loss of a span e.g. for Question Answering. + """ + + @property + def input_types(self): + """Returns definitions of module input ports. + """ + return { + "logits": NeuralType(('B', 'T', 'D'), LogitsType()), + "start_positions": NeuralType(tuple('B'), ChannelType()), + "end_positions": NeuralType(tuple('B'), ChannelType()), + } + + @property + def output_types(self): + """Returns definitions of module output ports. + """ + return { + "loss": NeuralType(elements_type=LossType()), + "start_logits": NeuralType(('B', 'T'), LogitsType()), + "end_logits": NeuralType(('B', 'T'), LogitsType()), + } + + def __init__(self,): + super().__init__() + + @typecheck() + def forward(self, logits, start_positions, end_positions): + """ + Args: + logits: Output of question answering head, which is a token classfier. + start_positions: Ground truth start positions of the answer w.r.t. + input sequence. If question is unanswerable, this will be + pointing to start token, e.g. [CLS], of the input sequence. + end_positions: Ground truth end positions of the answer w.r.t. + input sequence. If question is unanswerable, this will be + pointing to start token, e.g. [CLS], of the input sequence. + """ + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1) + end_logits = end_logits.squeeze(-1) + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + ignored_index = start_logits.size(1) + start_positions.clamp_(0, ignored_index) + end_positions.clamp_(0, ignored_index) + + loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + return total_loss, start_logits, end_logits diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/metrics/__init__.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/metrics/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ae2b312b6f771b40a83db2f6adec8af033fc331c --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/metrics/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nemo.collections.common.metrics.classification_accuracy import TopKClassificationAccuracy +from nemo.collections.common.metrics.perplexity import Perplexity diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/metrics/classification_accuracy.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/metrics/classification_accuracy.py new file mode 100644 index 0000000000000000000000000000000000000000..4e31fa9bd107a22f762c510633633740f032c6f4 --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/metrics/classification_accuracy.py @@ -0,0 +1,152 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List + +import torch +from pytorch_lightning.metrics import Metric + +__all__ = ['TopKClassificationAccuracy'] + + +class TopKClassificationAccuracy(Metric): + """ + This metric computes numerator and denominator for Overall Accuracy between logits and labels. + When doing distributed training/evaluation the result of res=TopKClassificationAccuracy(logits, labels) calls + will be all-reduced between all workers using SUM operations. + Here contains two numbers res=[correctly_predicted, total_samples]. Accuracy=correctly_predicted/total_samples. + + If used with PytorchLightning LightningModule, include correct_count and total_count inside validation_step results. + Then aggregate (sum) then at the end of validation epoch to correctly compute validation WER. + + Example: + def validation_step(self, batch, batch_idx): + ... + correct_count, total_count = self._accuracy(logits, labels) + return {'val_loss': loss_value, 'val_correct_count': correct_count, 'val_total_count': total_count} + + def validation_epoch_end(self, outputs): + ... + val_loss_mean = torch.stack([x['val_loss'] for x in outputs]).mean() + correct_counts = torch.stack([x['val_correct_counts'] for x in outputs]) + total_counts = torch.stack([x['val_total_counts'] for x in outputs]) + + topk_scores = compute_topk_accuracy(correct_counts, total_counts) + + tensorboard_log = {'val_loss': val_loss_mean} + for top_k, score in zip(self._accuracy.top_k, topk_scores): + tensorboard_log['val_epoch_top@{}'.format(top_k)] = score + + return {'log': tensorboard_log} + + Args: + top_k: Optional list of integers. Defaults to [1]. + + Returns: + res: a torch.Tensor object with two elements: [correct_count, total_count]. To correctly compute average + accuracy, compute acc=correct_count/total_count + """ + + def __init__(self, top_k=None, dist_sync_on_step=False): + super().__init__(dist_sync_on_step=dist_sync_on_step) + + if top_k is None: + top_k = [1] + + self.top_k = top_k + self.add_state( + "correct_counts_k", default=torch.zeros(len(self.top_k)), dist_reduce_fx='sum', persistent=False + ) + self.add_state("total_counts_k", default=torch.zeros(len(self.top_k)), dist_reduce_fx='sum', persistent=False) + + @torch.no_grad() + def top_k_predicted_labels(self, logits: torch.Tensor) -> torch.Tensor: + max_k = max(self.top_k) + _, predictions = logits.topk(max_k, dim=1, largest=True, sorted=True) + return predictions + + def update(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: + with torch.no_grad(): + predictions = self.top_k_predicted_labels(logits) + predictions = predictions.t() + correct = predictions.eq(labels.view(1, -1)).expand_as(predictions) + + correct_counts_k = [] + total_counts_k = [] + + for k in self.top_k: + correct_k = correct[:k].reshape(-1).float().sum() + total_k = labels.shape[0] + + correct_counts_k.append(correct_k) + total_counts_k.append(total_k) + + self.correct_counts_k = torch.tensor(correct_counts_k, dtype=labels.dtype, device=labels.device) + self.total_counts_k = torch.tensor(total_counts_k, dtype=labels.dtype, device=labels.device) + + def compute(self): + """ + Computes the top-k accuracy. + + Returns: + A list of length `K`, such that k-th index corresponds to top-k accuracy + over all distributed processes. + """ + if not len(self.correct_counts_k) == len(self.top_k) == len(self.total_counts_k): + raise ValueError("length of counts must match to topk length") + + if self.top_k == [1]: + return [self.correct_counts_k.float() / self.total_counts_k] + + else: + top_k_scores = compute_topk_accuracy(self.correct_counts_k, self.total_counts_k) + + return top_k_scores + + @property + def top_k(self) -> List[int]: + return self._top_k + + @top_k.setter + def top_k(self, value: List[int]): + if value is None: + value = [1] + + if type(value) == int: + value = [value] + + if type(value) != list: + value = list(value) + + self._top_k = value + + +def compute_topk_accuracy(correct_counts_k, total_counts_k): + """ + Computes the top-k accuracy + Args: + correct_counts: Tensor of shape [K], K being the top-k parameter. + total_counts: Tensor of shape [K], and K being the top-k parameter. + Returns: + A list of length `K`, such that k-th index corresponds to top-k accuracy + over all distributed processes. + """ + top_k_scores = [] + + for ki in range(len(correct_counts_k)): + correct_count = correct_counts_k[ki].item() + total_count = total_counts_k[ki].item() + top_k_scores.append(correct_count / float(total_count)) + + return top_k_scores diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/metrics/perplexity.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/metrics/perplexity.py new file mode 100644 index 0000000000000000000000000000000000000000..e2d093cee1288a67671e3a6a826778641c0abf94 --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/metrics/perplexity.py @@ -0,0 +1,74 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from pytorch_lightning.metrics import Metric +from torch.distributions.categorical import Categorical + +__all__ = ['Perplexity'] + + +class Perplexity(Metric): + """ + This class computes mean perplexity of distributions in the last dimension of inputs. It is a wrapper around + :doc:`torch.distributions.Categorical.perplexity` method. You have to provide either + ``probs`` or ``logits`` to the :meth:`update` method. The class computes perplexities for distributions passed to + :meth:`update` method in ``probs`` or ``logits`` arguments and averages the perplexities. Reducing results between + all workers is done via SUM operations. + + See :doc:`PyTorch Lightning Metrics` for the metric usage instructions. + + Args: + compute_on_step: + Forward only calls ``update()`` and returns ``None`` if this is set to ``False``. default: ``True`` + dist_sync_on_step: + Synchronize metric state across processes at each ``forward()`` + before returning the value at the step. + process_group: + Specify the process group on which synchronization is called. default: ``None`` (which selects the entire + world) + validate_args: + If ``True`` values of :meth:`update` method parameters are checked. ``logits`` has to not contain NaNs and + ``probs`` last dim has to be valid probability distribution. + """ + + def __init__(self, compute_on_step=True, dist_sync_on_step=False, process_group=None, validate_args=True): + super().__init__( + compute_on_step=compute_on_step, dist_sync_on_step=dist_sync_on_step, process_group=process_group + ) + self.validate_args = validate_args + self.add_state('perplexities_sum', torch.tensor(0.0, dtype=torch.float64), dist_reduce_fx='sum') + # Total number of distributions seen since last reset + self.add_state('num_distributions', torch.tensor(0, dtype=torch.int64), dist_reduce_fx='sum') + + def update(self, probs=None, logits=None): + """ + Updates :attr:`perplexities_sum` and :attr:`num_distributions`. + + Args: + probs: A ``torch.Tensor`` which innermost dimension is valid probability distribution. + logits: A ``torch.Tensor`` without NaNs. + """ + d = Categorical(probs, logits, validate_args=self.validate_args) + ppl = d.perplexity() + self.num_distributions += ppl.numel() + self.perplexities_sum += ppl.sum() + + def compute(self): + """ + Returns perplexity across all workers and resets to 0 :attr:`perplexities_sum` and :attr:`num_distributions`. + """ + if self.num_distributions.eq(0): + return None + return self.perplexities_sum / self.num_distributions diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/parts/__init__.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/parts/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7f3916e3655d5aa32650d6189e580431d9b8bc32 --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/parts/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nemo.collections.common.parts.multi_layer_perceptron import MultiLayerPerceptron +from nemo.collections.common.parts.transformer_utils import * +from nemo.collections.common.parts.utils import * diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/parts/mem_transformer.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/parts/mem_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..dd0197e4836af100044e805a72ba734332da5395 --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/parts/mem_transformer.py @@ -0,0 +1,531 @@ +import random + +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np + +from nemo.collections.common.parts.normalization import LayerVarNorm + +float32_min = np.finfo(np.float32).min +float16_min = np.finfo(np.float16).min + + +class PositionalEmbedding(nn.Module): + def __init__(self, demb): + super(PositionalEmbedding, self).__init__() + + self.demb = demb + + inv_freq = 1 / (10000 ** (torch.arange(0.0, demb, 2.0) / demb)) + self.register_buffer('inv_freq', inv_freq) + + def forward(self, pos_seq, bsz=None): + sinusoid_inp = torch.ger(pos_seq, self.inv_freq) + pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1) + + if bsz is not None: + return pos_emb[:,None,:].expand(-1, bsz, -1) + else: + return pos_emb[:,None,:] + + +def get_norm(norm_type, d_model, ln_eps): + if norm_type == 'ln': + norm = nn.LayerNorm(d_model, eps=ln_eps) + else: + assert norm_type == 'var_ln' + norm = LayerVarNorm(d_model, eps=ln_eps) + return norm + + +class PositionwiseFF(nn.Module): + def __init__(self, d_model, d_inner, *, dropout, pre_lnorm=False, ln_eps=1e-5, norm_type='ln'): + super(PositionwiseFF, self).__init__() + + self.d_model = d_model + self.d_inner = d_inner + self.dropout = dropout + + self.CoreNet = nn.Sequential( + nn.Linear(d_model, d_inner), nn.ReLU(inplace=True), + nn.Dropout(dropout), + nn.Linear(d_inner, d_model), + nn.Dropout(dropout), + ) + + self.layer_norm = get_norm(norm_type, d_model, ln_eps) + + self.pre_lnorm = pre_lnorm + + def forward(self, inp): + if self.pre_lnorm: + ##### layer normalization + positionwise feed-forward + core_out = self.CoreNet(self.layer_norm(inp)) + + ##### residual connection + output = core_out + inp + else: + ##### positionwise feed-forward + core_out = self.CoreNet(inp) + + ##### residual connection + layer normalization + output = self.layer_norm(inp + core_out) + + return output + + +class AdaptiveFFN(nn.Module): + def __init__(self, d_model, d_inner, *, dropout, ln_eps=1e-5, gate_mix_prob=0.5, gate_sample_prob=0.25, + identity_loss_weight=1.0, identity_threshold=0.9, init_identiy_bias=2.0, + ffn_residual=True, norm_in_ffn=True): + super(AdaptiveFFN, self).__init__() + + self.ffn_net = nn.Sequential( + nn.Linear(d_model, d_inner), + nn.ReLU(inplace=True), + nn.Dropout(dropout), + nn.Linear(d_inner, d_model), + nn.Dropout(dropout), + ) + self.ffn_residual = ffn_residual + + self.layer_norm = nn.LayerNorm(d_model, eps=ln_eps) + self.norm_in_ffn = norm_in_ffn + + self.gate_net = GateNet(d_model, gate_num=2, init_identiy_bias=init_identiy_bias) + self.gate_mix_prob = gate_mix_prob + self.gate_sample_prob = gate_sample_prob + assert 0 <= self.gate_mix_prob <= 1.0 + assert 0 <= self.gate_sample_prob <= 1.0 + assert 0 <= self.gate_mix_prob + self.gate_sample_prob <= 1.0 + self.identity_threshold = identity_threshold + + def forward(self, inputs, *, pad_mask): + # [T, B, D] => [T, B, gate_num] + gate_prob = self.gate_net(inputs) + # paddings may leads to NAN after softmax + # gate_prob = gate_prob.masked_fill(pad_mask.unsqueeze(2), 0.) + identity_prob, ffn_prob = torch.chunk(gate_prob, 2, dim=-1) + + ffn_output = self._ffn_forward(inputs) + + if self.training: + r = random.random() + if r < self.gate_mix_prob: + # learn gate + output = inputs * identity_prob + ffn_output * ffn_prob + adaptive_prob = identity_prob + elif r < self.gate_mix_prob + self.gate_sample_prob: + # exploit + identity_mask = torch.bernoulli(identity_prob) + output = inputs * identity_mask + ffn_output * (1 - identity_mask) + adaptive_prob = None + else: + # explore, by uniform sample branches + mask_size = (inputs.shape[0], inputs.shape[1], 1) + identity_mask = torch.bernoulli(torch.full(mask_size, 0.5, device=inputs.device)).type(inputs.dtype) + output = inputs * identity_mask + ffn_output * (1 - identity_mask) + adaptive_prob = None + else: + identity_mask = identity_prob > self.identity_threshold + output = inputs * identity_mask + ffn_output * ~identity_mask + adaptive_prob = identity_prob + + if not self.norm_in_ffn: + output = self.layer_norm(output) + + return output, adaptive_prob + + def _ffn_forward(self, inp): + output = inp + + output = self.ffn_net(output) + if self.ffn_residual: + output = output + inp + + if self.norm_in_ffn: + output = self.layer_norm(output) + + return output + + +class GateNet(nn.Module): + def __init__(self, in_features, gate_num, init_identiy_bias): + super(GateNet, self).__init__() + + assert gate_num == 2 + self.weight = nn.Parameter(torch.Tensor(gate_num, in_features)) + self.bias = nn.Parameter(torch.tensor([init_identiy_bias] + [0.] * (gate_num - 1))) + + def forward(self, inputs): + logits = F.linear(inputs, self.weight, self.bias) + prob = F.softmax(logits, dim=-1) + return prob + + +def _rel_shift_uni(x): + # x: qlen x rlen x bsz x n_head + zero_pad = torch.zeros((x.size(0), 1, *x.size()[2:]), + device=x.device, dtype=x.dtype) + x_padded = torch.cat([zero_pad, x], dim=1) + + x_padded = x_padded.view(x.size(1) + 1, x.size(0), *x.size()[2:]) + + x = x_padded[1:].view_as(x) + + # if zero_triu: + # ones = torch.ones((x.size(0), x.size(1))) + # x = x * torch.tril(ones, x.size(1) - x.size(0))[:,:,None,None] + + return x + + +def _rel_shift_bi(x, klen): + # x: qlen x rlen x bsz x n_head + """perform relative shift to form the relative attention score.""" + x_size = x.size() + assert klen * 2 == x_size[1] + + x = x.reshape(x_size[1], x_size[0], x_size[2], x_size[3]) + x = torch.narrow(x, dim=0, start=1, length=x.size()[0]-1) + x = x.reshape(x_size[0], x_size[1] - 1, x_size[2], x_size[3]) + x = torch.narrow(x, dim=1, start=0, length=klen) + + return x + + +def check_rel_shift_bi(): + """in x, 14 means query 1, rel emb at position 4, -32 mean query 3, rel emb at position -2""" + x = torch.tensor([[14, 13, 12, 11, 10, -11, -12, -13], + [24, 23, 22, 21, 20, -21, -22, -23], + [34, 33, 32, 31, 30, -31, -32, -33], + [44, 43, 42, 41, 40, -41, -42, -43]], dtype=torch.float32) + x = x.unsqueeze(-1).unsqueeze(-1) + shifted_x = _rel_shift_bi(x, klen=4) + shifted_x = shifted_x.squeeze(-1).squeeze(-1) + assert torch.equal(shifted_x, + torch.tensor([[10., -11., -12., -13.], + [21., 20., -21., -22.], + [32., 31., 30., -31.], + [43., 42., 41., 40.]])) + return shifted_x + + +def create_pad_mask(lens, max_len): + mask = torch.arange(max_len).to(lens.device) >= lens.unsqueeze(-1) + return mask + + +class RelPartialLearnableMultiHeadAttn(nn.Module): + def __init__(self, n_head, d_model, d_head, *, dropout, dropatt, pre_lnorm=False, ln_eps=1e-5, uni_attn=True, + norm_type='ln', pos_enc='xl'): + super(RelPartialLearnableMultiHeadAttn, self).__init__() + + self.n_head = n_head + self.d_model = d_model + self.d_head = d_head + self.dropout = dropout + + self.qkv_net = nn.Linear(d_model, 3 * n_head * d_head, bias=False) + + self.drop = nn.Dropout(dropout) + self.dropatt = nn.Dropout(dropatt) + self.o_net = nn.Linear(n_head * d_head, d_model, bias=False) + + self.layer_norm = get_norm(norm_type, d_model, ln_eps) + + self.scale = 1 / (d_head ** 0.5) + + self.pre_lnorm = pre_lnorm + + self.uni_attn = uni_attn + + self.pos_enc = pos_enc + if self.pos_enc == 'xl': + self.r_net = nn.Linear(self.d_model, self.n_head * self.d_head, bias=False) + else: + assert self.pos_enc is None + + def forward(self, w, r, r_w_bias, r_r_bias, attn_mask=None, mems=None): + qlen, bsz = w.size(0), w.size(1) + + if self.pre_lnorm: + w_norm = self.layer_norm(w) + else: + w_norm = w + w_heads = self.qkv_net(w_norm) + + w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1) + + klen = w_head_k.size(0) + + w_head_q = w_head_q.view(qlen, bsz, self.n_head, self.d_head) # qlen x bsz x n_head x d_head + w_head_k = w_head_k.view(klen, bsz, self.n_head, self.d_head) # qlen x bsz x n_head x d_head + w_head_v = w_head_v.view(klen, bsz, self.n_head, self.d_head) # qlen x bsz x n_head x d_head + + if mems is not None: + assert self.uni_attn + k_mems, v_mems, real_mlen = mems + new_mems = w_head_k, w_head_v + + w_head_k = torch.cat([k_mems, w_head_k], 0) + w_head_v = torch.cat([v_mems, w_head_v], 0) + else: + new_mems = None + + #### compute attention score + if self.pos_enc == 'xl': + rw_head_q = w_head_q + r_w_bias # qlen x bsz x n_head x d_head + else: + rw_head_q = w_head_q + attn_score = torch.einsum('ibnd,jbnd->ijbn', (rw_head_q, w_head_k)) # qlen x klen x bsz x n_head + + if self.pos_enc == 'xl': + rr_head_q = w_head_q + r_r_bias + r_head_k = self.r_net(r) + r_head_k = r_head_k.view(r.size(0), self.n_head, self.d_head) # qlen x n_head x d_head + BD = torch.einsum('ibnd,jnd->ijbn', (rr_head_q, r_head_k)) # qlen x klen x bsz x n_head + BD = self._rel_shift(BD, attn_score.size(1)) + # [qlen x klen x bsz x n_head] + attn_score = attn_score + BD + + attn_score.mul_(self.scale) + + neg_min = float16_min if attn_score.dtype == torch.float16 else float32_min + #### compute attention probability + if self.uni_attn: + # attn_mask: [qlen, klen, 1] -> [qlen, klen, 1, 1] + attn_score = attn_score.float().masked_fill( + attn_mask.unsqueeze(-1), neg_min).type_as(attn_score) + else: + # attn_mask: [klen, bsz] -> [1, klen, bsz, 1] + attn_score = attn_score.masked_fill(attn_mask.unsqueeze(0).unsqueeze(-1), neg_min) + + # [qlen x klen x bsz x n_head] + attn_prob = F.softmax(attn_score, dim=1) + attn_prob = self.dropatt(attn_prob) + + #### compute attention vector + attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, w_head_v)) + + # [qlen x bsz x n_head x d_head] + attn_vec = attn_vec.contiguous().view( + attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head) + + ##### linear projection + attn_out = self.o_net(attn_vec) + attn_out = self.drop(attn_out) + + if self.pre_lnorm: + ##### residual connection + output = w + attn_out + else: + ##### residual connection + layer normalization + output = self.layer_norm(w + attn_out) + + return output, new_mems + + def _rel_shift(self, x, klen): + if self.uni_attn: + return _rel_shift_uni(x) + else: + return _rel_shift_bi(x, klen) + + +class RelPartialLearnableDecoderLayer(nn.Module): + def __init__(self, n_head, d_model, d_head, d_inner, *, dropout, dropatt, pre_lnorm, ln_eps, uni_attn, norm_type='ln', + pos_enc='xl', adaptive_ffn=None): + super(RelPartialLearnableDecoderLayer, self).__init__() + + self.dec_attn = RelPartialLearnableMultiHeadAttn(n_head, d_model, d_head, dropout=dropout, dropatt=dropatt, + pre_lnorm=pre_lnorm, ln_eps=ln_eps, uni_attn=uni_attn, + norm_type=norm_type, pos_enc=pos_enc) + self.use_adaptive_ffn = adaptive_ffn is not None + if not self.use_adaptive_ffn: + self.pos_ff = PositionwiseFF(d_model, d_inner, dropout=dropout, pre_lnorm=pre_lnorm, ln_eps=ln_eps, norm_type=norm_type) + else: + assert not pre_lnorm and norm_type == 'ln' + self.pos_ff = AdaptiveFFN(d_model, d_inner, dropout=dropout, ln_eps=ln_eps, **adaptive_ffn) + + def forward(self, dec_inp, r, r_w_bias, r_r_bias, *, dec_attn_mask=None, pad_mask=None, mems=None): + + output, new_mems = self.dec_attn(dec_inp, r, r_w_bias, r_r_bias, + attn_mask=dec_attn_mask, + mems=mems) + if not self.use_adaptive_ffn: + output = self.pos_ff(output) + ada_prob = None + else: + output, ada_prob = self.pos_ff(output, pad_mask=pad_mask) + + return output, new_mems, (ada_prob,) + + +class RelTransformerBlock(nn.Module): + def __init__(self, n_layer, d_model, n_head, d_head, d_inner, dropout, dropout_att, + pre_lnorm=False, norm_output=False, ln_eps=1e-5, uni_attn=True, norm_type='ln', pos_enc='xl', + layer_drop=0.0, adaptive_ffn=None): + super(RelTransformerBlock, self).__init__() + + self.n_layer = n_layer + self.d_model = d_model + self.n_head = n_head + self.d_head = d_head + + self.drop = nn.Dropout(dropout) + + self.uni_attn = uni_attn + self.att_trunc_len = -1 + + self.clamp_len = 0 + + self.layer_drop = layer_drop + + self.layers = nn.ModuleList() + for i in range(n_layer): + self.layers.append( + RelPartialLearnableDecoderLayer( + n_head, d_model, d_head, d_inner, dropout=dropout, + dropatt=dropout_att, pre_lnorm=pre_lnorm, ln_eps=ln_eps, uni_attn=uni_attn, norm_type=norm_type, + pos_enc=pos_enc, adaptive_ffn=adaptive_ffn) + ) + + if norm_output: + self.output_norm = nn.LayerNorm(d_model, eps=ln_eps) + else: + self.output_norm = identity + + self.use_adaptive_ffn = adaptive_ffn is not None + if self.use_adaptive_ffn: + self.identity_loss_weight = adaptive_ffn.identity_loss_weight + + self.pos_enc = pos_enc + self._create_params() + + def _create_params(self): + if self.pos_enc == 'xl': + self.pos_emb = PositionalEmbedding(self.d_model) + self.r_w_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head)) + self.r_r_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head)) + else: + assert self.pos_enc is None + self.pos_emb = None + self.r_w_bias = None + self.r_r_bias = None + + def _forward(self, dec_inp, lens=None, mems=None): + is_decoding = mems is not None + qlen, bsz, _ = dec_inp.size() + + mlen = mems[0][0].size(0) if mems is not None else 0 + klen = mlen + qlen + + if not self.uni_attn or self.use_adaptive_ffn: + assert lens is not None + pad_mask = create_pad_mask(lens, max_len=qlen) + # [B, L] -> [L, B] + pad_mask = pad_mask.transpose(0, 1).contiguous() + else: + pad_mask = None + + if self.uni_attn: + dec_attn_mask = torch.triu( + dec_inp.new_ones(qlen, klen), diagonal=1+mlen).bool()[:,:,None] + else: + assert pad_mask is not None + dec_attn_mask = pad_mask + + hids = [] + new_kv_mems = [] + core_out = dec_inp + + if self.pos_enc == 'xl': + if self.uni_attn: + pos_s, pos_e = klen-1, -1 + else: + pos_s, pos_e = klen, -qlen + pos_seq = torch.arange(pos_s, pos_e, -1.0, device=dec_inp.device, + dtype=dec_inp.dtype) + if self.clamp_len > 0: + pos_seq.clamp_(max=self.clamp_len) + pos_emb = self.pos_emb(pos_seq) + + pos_emb = self.drop(pos_emb) + else: + pos_emb = None + + if self.use_adaptive_ffn: + adaptive_prob = [] + else: + adaptive_prob = None + + # hids.append(core_out) + for i, layer in enumerate(self.layers): + if self.layer_drop > 0 and self.training and np.random.random() < self.layer_drop: + continue + + mems_i = None if not is_decoding else mems[i] + core_out, kv_mems, extra = layer(core_out, pos_emb, self.r_w_bias, + self.r_r_bias, dec_attn_mask=dec_attn_mask, pad_mask=pad_mask, mems=mems_i) + # hids.append(core_out) + new_kv_mems.append(kv_mems) + + if self.use_adaptive_ffn: + if extra[0] is not None: + adaptive_prob.append(extra[0]) + + core_out = self.output_norm(core_out) + + if self.use_adaptive_ffn: + if len(adaptive_prob) == 0: + assert self.training + adaptive_ffn_loss = torch.tensor(0., dtype=dec_inp.dtype, device=dec_inp.device) + else: + adaptive_prob_t = torch.stack(adaptive_prob) + zero_guard_eps = 1e-12 + adaptive_logp = torch.log(adaptive_prob_t + zero_guard_eps) + # pad_mask: [L, B] => [1, L, B, 1] + adaptive_logp = adaptive_logp.masked_fill(pad_mask.unsqueeze(-1).unsqueeze(0), 0.) + avg_adaptive_logp = adaptive_logp.sum() / (len(adaptive_prob) * lens.sum()) + adaptive_ffn_loss = -avg_adaptive_logp * self.identity_loss_weight + else: + adaptive_ffn_loss = None + + new_mems = [] + if is_decoding: + new_mems = self._update_decode_mems(new_kv_mems, mems, [self.att_trunc_len] * len(self.layers)) + + return core_out, new_mems, (adaptive_ffn_loss,) + + def forward(self, data, *, lens=None, mems=None): + # data: [T, B, D] + output, new_mems, extra = self._forward(data, lens=lens, mems=mems) + + if mems is None: + assert len(new_mems) == 0 + + return output, tuple(new_mems), extra + + def _update_decode_mems(self, step_mem, prev_mems, mem_len): + assert prev_mems is not None and len(step_mem) == len(prev_mems) + + with torch.no_grad(): + new_mems = [] + for i in range(len(step_mem)): + mem_len_i = mem_len[i] + k_mem, v_mem = step_mem[i] + k_prev_mem, v_prev_mem, real_mlen = prev_mems[i] + new_k_mem = torch.cat([k_prev_mem, k_mem], 0) + new_v_mem = torch.cat([v_prev_mem, v_mem], 0) + real_mlen = real_mlen + k_mem.size(0) + if mem_len_i > 0: + new_k_mem = self.neg_slice(new_k_mem, mem_len, True) + new_v_mem = self.neg_slice(new_v_mem, mem_len, True) + real_mlen = torch.min(real_mlen, mem_len) + new_mems.append((new_k_mem, new_v_mem, real_mlen)) + return new_mems + + +def identity(x): + return x diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/parts/multi_layer_perceptron.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/parts/multi_layer_perceptron.py new file mode 100644 index 0000000000000000000000000000000000000000..76c06bf23ea647148d4eaa0a9686e631c169a436 --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/parts/multi_layer_perceptron.py @@ -0,0 +1,61 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + + +class MultiLayerPerceptron(torch.nn.Module): + """ + A simple MLP that can either be used independently or put on top + of pretrained models (such as BERT) and act as a classifier. + Args: + hidden_size (int): the size of each layer + num_classes (int): number of output classes + num_layers (int): number of layers + activation (str): type of activations for layers in between + log_softmax (bool): whether to add a log_softmax layer before output + """ + + def __init__( + self, + hidden_size: int, + num_classes: int, + num_layers: int = 2, + activation: str = 'relu', + log_softmax: bool = True, + ): + super().__init__() + self.layers = 0 + for _ in range(num_layers - 1): + layer = torch.nn.Linear(hidden_size, hidden_size) + setattr(self, f'layer{self.layers}', layer) + setattr(self, f'layer{self.layers + 1}', getattr(torch, activation)) + self.layers += 2 + layer = torch.nn.Linear(hidden_size, num_classes) + setattr(self, f'layer{self.layers}', layer) + self.layers += 1 + self.log_softmax = log_softmax + + @property + def last_linear_layer(self): + return getattr(self, f'layer{self.layers - 1}') + + def forward(self, hidden_states): + output_states = hidden_states[:] + for i in range(self.layers): + output_states = getattr(self, f'layer{i}')(output_states) + + if self.log_softmax: + output_states = torch.log_softmax(output_states, dim=-1) + return output_states diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/parts/normalization.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/parts/normalization.py new file mode 100644 index 0000000000000000000000000000000000000000..e80e2bef0e7795479d65b3a775651a0607fcc335 --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/parts/normalization.py @@ -0,0 +1,37 @@ +import math +import numbers + +import torch +import torch.nn as nn + + +class LayerVarNorm(nn.Module): + __constants__ = ['normalized_shape', 'weight', 'eps', 'elementwise_affine'] + + def __init__(self, normalized_shape, eps=1e-6, elementwise_affine=False): + super(LayerVarNorm, self).__init__() + if isinstance(normalized_shape, numbers.Integral): + normalized_shape = (normalized_shape,) + self.normalized_shape = tuple(normalized_shape) + self.eps = math.sqrt(eps) # eps is added directly on std, i.e., outside sqrt(var) + self.elementwise_affine = elementwise_affine + if self.elementwise_affine: + self.weight = nn.Parameter(torch.Tensor(*normalized_shape)) + else: + self.register_parameter('weight', None) + self.reset_parameters() + + def reset_parameters(self): + if self.elementwise_affine: + nn.init.ones_(self.weight) + + def forward(self, input): + std = torch.std(input, dim=-1, unbiased=False, keepdim=True) + output = input / (std + self.eps) + if self.elementwise_affine: + output = output * self.weight + return output + + def extra_repr(self): + return '{normalized_shape}, eps={eps}, ' \ + 'elementwise_affine={elementwise_affine}'.format(**self.__dict__) \ No newline at end of file diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/parts/rnn.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/parts/rnn.py new file mode 100644 index 0000000000000000000000000000000000000000..ea50d7e1918e4603e4eb044a09cc721b0c75df87 --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/parts/rnn.py @@ -0,0 +1,510 @@ +# Copyright (c) 2019, Myrtle Software Limited. All rights reserved. +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch + + +def rnn( + input_size: int, + hidden_size: int, + num_layers: int, + norm: Optional[str] = None, + forget_gate_bias: Optional[float] = 1.0, + dropout: Optional[float] = 0.0, + norm_first_rnn: Optional[bool] = None, + t_max: Optional[int] = None, +) -> torch.nn.Module: + """ + Utility function to provide unified interface to common LSTM RNN modules. + + Args: + input_size: Input dimension. + + hidden_size: Hidden dimension of the RNN. + + num_layers: Number of RNN layers. + + norm: Optional string representing type of normalization to apply to the RNN. + Supported values are None, batch and layer. + + forget_gate_bias: float, set by default to 1.0, which constructs a forget gate + initialized to 1.0. + Reference: + [An Empirical Exploration of Recurrent Network Architectures](http://proceedings.mlr.press/v37/jozefowicz15.pdf) + + dropout: Optional dropout to apply to end of multi-layered RNN. + + norm_first_rnn: Whether to normalize the first RNN layer. + + t_max: int value, set to None by default. If an int is specified, performs Chrono Initialization + of the LSTM network, based on the maximum number of timesteps `t_max` expected during the course + of training. + Reference: + [Can recurrent neural networks warp time?](https://openreview.net/forum?id=SJcKhk-Ab) + + Returns: + A RNN module + """ + if norm not in [None, "batch", "layer"]: + raise ValueError(f"unknown norm={norm}") + + if norm is None: + return LSTMDropout( + input_size=input_size, + hidden_size=hidden_size, + num_layers=num_layers, + dropout=dropout, + forget_gate_bias=forget_gate_bias, + t_max=t_max, + ) + + if norm == "batch": + return BNRNNSum( + input_size=input_size, + hidden_size=hidden_size, + rnn_layers=num_layers, + batch_norm=True, + dropout=dropout, + forget_gate_bias=forget_gate_bias, + t_max=t_max, + norm_first_rnn=norm_first_rnn, + ) + + if norm == "layer": + return torch.jit.script( + ln_lstm( # torch.jit.script( + input_size=input_size, + hidden_size=hidden_size, + num_layers=num_layers, + dropout=dropout, + forget_gate_bias=forget_gate_bias, + t_max=t_max, + ) + ) + + +class OverLastDim(torch.nn.Module): + """Collapses a tensor to 2D, applies a module, and (re-)expands the tensor. + An n-dimensional tensor of shape (s_1, s_2, ..., s_n) is first collapsed to + a tensor with shape (s_1*s_2*...*s_n-1, s_n). The module is called with + this as input producing (s_1*s_2*...*s_n-1, s_n') --- note that the final + dimension can change. This is expanded to (s_1, s_2, ..., s_n-1, s_n') and + returned. + Args: + module (torch.nn.Module): Module to apply. Must accept a 2D tensor as + input and produce a 2D tensor as output, optionally changing the + size of the last dimension. + """ + + def __init__(self, module: torch.nn.Module): + super().__init__() + self.module = module + + def forward(self, x: torch.Tensor) -> torch.Tensor: + *dims, _ = x.size() + + reduced_dims = 1 + for dim in dims: + reduced_dims *= dim + + x = x.view(reduced_dims, -1) + x = self.module(x) + x = x.view(*dims, -1) + return x + + +class LSTMDropout(torch.nn.Module): + def __init__( + self, + input_size: int, + hidden_size: int, + num_layers: int, + dropout: Optional[float], + forget_gate_bias: Optional[float], + t_max: Optional[int] = None, + ): + """Returns an LSTM with forget gate bias init to `forget_gate_bias`. + Args: + input_size: See `torch.nn.LSTM`. + hidden_size: See `torch.nn.LSTM`. + num_layers: See `torch.nn.LSTM`. + dropout: See `torch.nn.LSTM`. + + forget_gate_bias: float, set by default to 1.0, which constructs a forget gate + initialized to 1.0. + Reference: + [An Empirical Exploration of Recurrent Network Architectures](http://proceedings.mlr.press/v37/jozefowicz15.pdf) + + t_max: int value, set to None by default. If an int is specified, performs Chrono Initialization + of the LSTM network, based on the maximum number of timesteps `t_max` expected during the course + of training. + Reference: + [Can recurrent neural networks warp time?](https://openreview.net/forum?id=SJcKhk-Ab) + + Returns: + A `torch.nn.LSTM`. + """ + super(LSTMDropout, self).__init__() + + self.lstm = torch.nn.LSTM( + input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, dropout=dropout, + ) + + if t_max is not None: + # apply chrono init + for name, v in self.lstm.named_parameters(): + if 'bias' in name: + p = getattr(self.lstm, name) + n = p.nelement() + hidden_size = n // 4 + p.data.fill_(0) + p.data[hidden_size : 2 * hidden_size] = torch.log( + torch.nn.init.uniform_(p.data[0:hidden_size], 1, t_max - 1) + ) + # forget gate biases = log(uniform(1, Tmax-1)) + p.data[0:hidden_size] = -p.data[hidden_size : 2 * hidden_size] + # input gate biases = -(forget gate biases) + + elif forget_gate_bias is not None: + for name, v in self.lstm.named_parameters(): + if "bias_ih" in name: + bias = getattr(self.lstm, name) + bias.data[hidden_size : 2 * hidden_size].fill_(forget_gate_bias) + if "bias_hh" in name: + bias = getattr(self.lstm, name) + bias.data[hidden_size : 2 * hidden_size].fill_(0) + + self.dropout = torch.nn.Dropout(dropout) if dropout else None + + def forward( + self, x: torch.Tensor, h: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None + ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + x, h = self.lstm(x, h) + + if self.dropout: + x = self.dropout(x) + + return x, h + + +class RNNLayer(torch.nn.Module): + """A single RNNLayer with optional batch norm.""" + + def __init__( + self, + input_size: int, + hidden_size: int, + rnn_type: torch.nn.Module = torch.nn.LSTM, + batch_norm: bool = True, + forget_gate_bias: Optional[float] = 1.0, + t_max: Optional[int] = None, + ): + super().__init__() + + if batch_norm: + self.bn = OverLastDim(torch.nn.BatchNorm1d(input_size)) + + if isinstance(rnn_type, torch.nn.LSTM) and not batch_norm: + # batch_norm will apply bias, no need to add a second to LSTM + self.rnn = LSTMDropout( + input_size=input_size, + hidden_size=hidden_size, + num_layers=1, + dropout=0.0, + forget_gate_bias=forget_gate_bias, + t_max=t_max, + ) + else: + self.rnn = rnn_type(input_size=input_size, hidden_size=hidden_size, bias=not batch_norm) + + def forward( + self, x: torch.Tensor, hx: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None + ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + if hasattr(self, 'bn'): + x = x.contiguous() + x = self.bn(x) + x, h = self.rnn(x, hx=hx) + return x, h + + def _flatten_parameters(self): + self.rnn.flatten_parameters() + + +class BNRNNSum(torch.nn.Module): + """RNN wrapper with optional batch norm. + Instantiates an RNN. If it is an LSTM it initialises the forget gate + bias =`lstm_gate_bias`. Optionally applies a batch normalisation layer to + the input with the statistics computed over all time steps. If dropout > 0 + then it is applied to all layer outputs except the last. + """ + + def __init__( + self, + input_size: int, + hidden_size: int, + rnn_type: torch.nn.Module = torch.nn.LSTM, + rnn_layers: int = 1, + batch_norm: bool = True, + dropout: Optional[float] = 0.0, + forget_gate_bias: Optional[float] = 1.0, + norm_first_rnn: bool = False, + t_max: Optional[int] = None, + ): + super().__init__() + self.rnn_layers = rnn_layers + + self.layers = torch.nn.ModuleList() + for i in range(rnn_layers): + final_layer = (rnn_layers - 1) == i + + self.layers.append( + RNNLayer( + input_size, + hidden_size, + rnn_type=rnn_type, + batch_norm=batch_norm and (norm_first_rnn or i > 0), + forget_gate_bias=forget_gate_bias, + t_max=t_max, + ) + ) + + if dropout is not None and dropout > 0.0 and not final_layer: + self.layers.append(torch.nn.Dropout(dropout)) + + input_size = hidden_size + + def forward( + self, x: torch.Tensor, hx: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None + ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + hx = self._parse_hidden_state(hx) + + hs = [] + cs = [] + rnn_idx = 0 + for layer in self.layers: + if isinstance(layer, torch.nn.Dropout): + x = layer(x) + else: + x, h_out = layer(x, hx=hx[rnn_idx]) + hs.append(h_out[0]) + cs.append(h_out[1]) + rnn_idx += 1 + del h_out + + h_0 = torch.stack(hs, dim=0) + c_0 = torch.stack(cs, dim=0) + return x, (h_0, c_0) + + def _parse_hidden_state( + self, hx: Optional[Tuple[torch.Tensor, torch.Tensor]] + ) -> Union[List[None], List[Tuple[torch.Tensor, torch.Tensor]]]: + """ + Dealing w. hidden state: + Typically in pytorch: (h_0, c_0) + h_0 = ``[num_layers * num_directions, batch, hidden_size]`` + c_0 = ``[num_layers * num_directions, batch, hidden_size]`` + """ + if hx is None: + return [None] * self.rnn_layers + else: + h_0, c_0 = hx + + if h_0.shape[0] != self.rnn_layers: + raise ValueError( + 'Provided initial state value `h_0` must be of shape : ' + '[num_layers * num_directions, batch, hidden_size]' + ) + + return [(h_0[i], c_0[i]) for i in range(h_0.shape[0])] + + def _flatten_parameters(self): + for layer in self.layers: + if isinstance(layer, (torch.nn.LSTM, torch.nn.GRU, torch.nn.RNN)): + layer._flatten_parameters() + + +class StackTime(torch.nn.Module): + """ + Stacks time within the feature dim, so as to behave as a downsampling operation. + """ + + def __init__(self, factor: int): + super().__init__() + self.factor = int(factor) + + def forward(self, x: List[Tuple[torch.Tensor]]) -> (torch.Tensor, torch.Tensor): + # T, B, U + x, x_lens = x + seq = [x] + for i in range(1, self.factor): + tmp = torch.zeros_like(x) + tmp[:-i, :, :] = x[i:, :, :] + seq.append(tmp) + x_lens = torch.ceil(x_lens.float() / self.factor).int() + return torch.cat(seq, dim=2)[:: self.factor, :, :], x_lens + + +def ln_lstm( + input_size: int, + hidden_size: int, + num_layers: int, + dropout: Optional[float], + forget_gate_bias: Optional[float], + t_max: Optional[int], +) -> torch.nn.Module: + """Returns a ScriptModule that mimics a PyTorch native LSTM.""" + # The following are not implemented. + if dropout is not None and dropout != 0.0: + raise ValueError('`dropout` not supported with LayerNormLSTM') + + if t_max is not None: + raise ValueError("LayerNormLSTM does not support chrono init") + + return StackedLSTM( + num_layers, + LSTMLayer, + first_layer_args=[LayerNormLSTMCell, input_size, hidden_size, forget_gate_bias], + other_layer_args=[LayerNormLSTMCell, hidden_size, hidden_size, forget_gate_bias], + ) + + +class LSTMLayer(torch.nn.Module): + def __init__(self, cell, *cell_args): + super(LSTMLayer, self).__init__() + self.cell = cell(*cell_args) + + def forward( + self, input: torch.Tensor, state: Tuple[torch.Tensor, torch.Tensor] + ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + inputs = input.unbind(0) + outputs = [] + for i in range(len(inputs)): + out, state = self.cell(inputs[i], state) + outputs += [out] + return torch.stack(outputs), state + + +class LayerNormLSTMCell(torch.nn.Module): + def __init__(self, input_size, hidden_size, forget_gate_bias): + super().__init__() + self.input_size = input_size + self.hidden_size = hidden_size + self.weight_ih = torch.nn.Parameter(torch.randn(4 * hidden_size, input_size)) + self.weight_hh = torch.nn.Parameter(torch.randn(4 * hidden_size, hidden_size)) + + # LayerNorm provide learnable biases + self.layernorm_i = torch.nn.LayerNorm(4 * hidden_size) + self.layernorm_h = torch.nn.LayerNorm(4 * hidden_size) + self.layernorm_c = torch.nn.LayerNorm(hidden_size) + + self.reset_parameters() + + self.layernorm_i.bias.data[hidden_size : 2 * hidden_size].fill_(0.0) + self.layernorm_h.bias.data[hidden_size : 2 * hidden_size].fill_(forget_gate_bias) + + def reset_parameters(self): + stdv = 1.0 / math.sqrt(self.hidden_size) + for weight in self.parameters(): + torch.nn.init.uniform_(weight, -stdv, stdv) + + def forward( + self, input: torch.Tensor, state: Tuple[torch.Tensor, torch.Tensor] + ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + hx, cx = state + igates = self.layernorm_i(torch.mm(input, self.weight_ih.t())) + hgates = self.layernorm_h(torch.mm(hx, self.weight_hh.t())) + gates = igates + hgates + ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1) + + ingate = torch.sigmoid(ingate) + forgetgate = torch.sigmoid(forgetgate) + cellgate = torch.tanh(cellgate) + outgate = torch.sigmoid(outgate) + + cy = self.layernorm_c((forgetgate * cx) + (ingate * cellgate)) + hy = outgate * torch.tanh(cy) + + return hy, (hy, cy) + + +def init_stacked_lstm( + num_layers: int, layer: torch.nn.Module, first_layer_args: List, other_layer_args: List +) -> torch.nn.ModuleList: + layers = [layer(*first_layer_args)] + [layer(*other_layer_args) for _ in range(num_layers - 1)] + return torch.nn.ModuleList(layers) + + +class StackedLSTM(torch.nn.Module): + def __init__(self, num_layers: int, layer: torch.nn.Module, first_layer_args: List, other_layer_args: List): + super(StackedLSTM, self).__init__() + self.layers: torch.nn.ModuleList = init_stacked_lstm(num_layers, layer, first_layer_args, other_layer_args) + + def forward( + self, input: torch.Tensor, states: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] + ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]: + if states is None: + temp_states: List[Tuple[torch.Tensor, torch.Tensor]] = [] + batch = input.size(1) + for layer in self.layers: + temp_states.append( + ( + torch.zeros(batch, layer.cell.hidden_size, dtype=input.dtype, device=input.device), + torch.zeros(batch, layer.cell.hidden_size, dtype=input.dtype, device=input.device), + ) + ) + + states = temp_states + + output_states: List[Tuple[torch.Tensor, torch.Tensor]] = [] + output = input + for i, rnn_layer in enumerate(self.layers): + state = states[i] + output, out_state = rnn_layer(output, state) + output_states.append(out_state) + i += 1 + return output, output_states + + +def label_collate(labels, device=None): + """Collates the label inputs for the rnn-t prediction network. + If `labels` is already in torch.Tensor form this is a no-op. + + Args: + labels: A torch.Tensor List of label indexes or a torch.Tensor. + device: Optional torch device to place the label on. + + Returns: + A padded torch.Tensor of shape (batch, max_seq_len). + """ + + if isinstance(labels, torch.Tensor): + assert labels.dtype == torch.int32 or labels.dtype == torch.int64 + return labels.type(torch.int64) + if not isinstance(labels, (list, tuple)): + raise ValueError(f"`labels` should be a list or tensor not {type(labels)}") + + batch_size = len(labels) + max_len = max(len(label) for label in labels) + + cat_labels = np.full((batch_size, max_len), fill_value=0.0, dtype=np.int32) + for e, l in enumerate(labels): + cat_labels[e, : len(l)] = l + labels = torch.tensor(cat_labels, dtype=torch.int64, device=device) + + return labels diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/parts/transformer_utils.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/parts/transformer_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..467ee9011915a78cd1956229408d1a358708cd65 --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/parts/transformer_utils.py @@ -0,0 +1,79 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn + +__all__ = ['NEG_INF', 'form_attention_mask', 'transformer_weights_init', 'mask_padded_tokens'] + +NEG_INF = -10000.0 + + +def form_attention_mask(input_mask, diagonal=None): + """ + Build attention mask with optional masking of future tokens we forbid + to attend to (e.g. as it is in Transformer decoder). + + Args: + input_mask: binary mask of size B x L with 1s corresponding to valid + tokens and 0s corresponding to padding tokens + diagonal: diagonal where triangular future mask starts + None -- do not mask anything + 0 -- regular translation or language modeling future masking + 1 -- query stream masking as in XLNet architecture + Returns: + attention_mask: mask of size B x 1 x L x L with 0s corresponding to + tokens we plan to attend to and -10000 otherwise + """ + + if input_mask is None: + return None + attn_shape = (1, input_mask.shape[1], input_mask.shape[1]) + attn_mask = input_mask.byte().unsqueeze(1) + if diagonal is not None: + future_mask = torch.tril(torch.ones(attn_shape).byte().to(input_mask.device), diagonal) + attn_mask = attn_mask & future_mask + attention_mask = (1 - attn_mask.to(torch.float)) * NEG_INF + return attention_mask.unsqueeze(1) + + +def transformer_weights_init(module, std_init_range=0.02, xavier=True): + """ + Initialize different weights in Transformer model. + + Args: + module: torch.nn.Module to be initialized + std_init_range: standard deviation of normal initializer + xavier: if True, xavier initializer will be used in Linear layers + as was proposed in AIAYN paper, otherwise normal initializer + will be used (like in BERT paper) + """ + + if isinstance(module, nn.Linear): + if xavier: + nn.init.xavier_uniform_(module.weight) + else: + nn.init.normal_(module.weight, mean=0.0, std=std_init_range) + if module.bias is not None: + nn.init.constant_(module.bias, 0.0) + elif isinstance(module, nn.Embedding): + nn.init.normal_(module.weight, mean=0.0, std=std_init_range) + elif isinstance(module, nn.LayerNorm): + nn.init.constant_(module.weight, 1.0) + nn.init.constant_(module.bias, 0.0) + + +def mask_padded_tokens(tokens, pad_id): + mask = tokens != pad_id + return mask diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/parts/utils.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/parts/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0542e5dce94ee7c9d9107754d8dc1c575d7c46f1 --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/parts/utils.py @@ -0,0 +1,57 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import os +from typing import List + +__all__ = ['if_exist', '_compute_softmax'] + + +def if_exist(outfold: str, files: List[str]): + """ + Returns true if all given files exist in the given folder + Args: + outfold: folder path + files: list of file names relative to outfold + """ + if not os.path.exists(outfold): + return False + for file in files: + if not os.path.exists(f'{outfold}/{file}'): + return False + return True + + +def _compute_softmax(scores): + """Compute softmax probability over raw logits.""" + if not scores: + return [] + + max_score = None + for score in scores: + if max_score is None or score > max_score: + max_score = score + + exp_scores = [] + total_sum = 0.0 + for score in scores: + x = math.exp(score - max_score) + exp_scores.append(x) + total_sum += x + + probs = [] + for score in exp_scores: + probs.append(score / total_sum) + return probs diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/tokenizers/__init__.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/tokenizers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ae6f0950d6ac003d6534fb0f8c8be731b33b30c3 --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/tokenizers/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nemo.collections.common.tokenizers.char_tokenizer import CharTokenizer +from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer +from nemo.collections.common.tokenizers.sentencepiece_tokenizer import SentencePieceTokenizer +from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec +from nemo.collections.common.tokenizers.word_tokenizer import WordTokenizer diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/tokenizers/char_tokenizer.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/tokenizers/char_tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..27674256f315838bdfa8f6f3f60444306281ff0d --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/tokenizers/char_tokenizer.py @@ -0,0 +1,521 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json +import os +import warnings +from collections import Counter +from enum import Enum +from pathlib import Path +from typing import Dict, List, NewType, Optional, Union + +from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec + +__all__ = ['CharTokenizer'] + + +NUMBER_OF_CHARACTERS_READ_BUFFER_SIZE = 10 ** 7 + + +class SpecialTokenString(Enum): + MASK = 'mask' + BOS = 'bos' + EOS = 'eos' + PAD = 'pad' + SEP = 'sep' + CLS = 'cls' + UNK = 'unk' + + @classmethod + def has_value(cls, value): + return value in cls._value2member_map_ + + +SpecialTokenStringType = NewType('SpecialTokenString', SpecialTokenString) + + +class CharTokenizer(TokenizerSpec): + rf""" + Each character is a token. + Args: + vocab_file: path to file with vocabulary for a tokenizer. The file consists of valid Python string literals + separated by the new line character. Such literals must contain 1 character. Examples of valid Python + literals: ``'a'``, ``'\n'``, ``"'"``, ``'ж'``, ``'\u8976'``. Optionally the first line in the file can be a + JSON dictionary of special tokens. The keys of the special tokens dictionary are ``'mask_token'``, + ``'bos_token'`` and so on. Some special tokens names can be omitted in the special tokens dictionary line. + A file ``vocab_file`` has to be in ``'utf-8'`` encoding. + mask_token: mask token. The following is applicable to all special tokens. Parameter ``mask_token`` is used + for adding mask token to vocabulary or for modification of mask token present in special tokens dictionary + in the first line of file ``vocab_file``. Parameter ``mask_token`` can be either of type ``bool`` or a + ``str`` of length 1. + + If ``mask_token`` is ``bool`` it has to be ``False``. If ``mask_token`` is ``True`` an exception is raised. + If ``mask_token`` is ``False`` and ``mask_token`` is present in special tokens dictionary in vocabulary + file ``vocab_file``, then ``mask_token`` is remove from special tokens dictionary. + + If the parameter ``mask_token`` is a string, then such strings in the input sequence are interpreted as + mask tokens. + bos_token: the beginning of sequence token. See more in ``mask_token`` parameter description. + eos_token: the end of sequence token. Usually equal to sep_token. See more in ``mask_token`` parameter + description. + pad_token: token to use for padding. See more in ``mask_token`` parameter description. + sep_token: token used for separating sequences. See more in ``mask_token`` parameter description. + cls_token: class token. Usually equal to bos_token. See more in ``mask_token`` parameter description. + unk_token: token to use for unknown tokens. If the parameter ``unk_token`` is set and there is a character + in the input of ``text_to_ids`` of ``text_to_tokens`` methods which is not in the vocabulary, then + such an unknown character is tokenized into ``unk_token``. If the parameter ``unk_token`` is ``False``, + then unknown tokens are discarded. See more in ``mask_token`` parameter description. + special_token_to_prepend: special token to prepend to the output of ``text_to_ids`` of ``text_to_tokens`` + methods. This option can be used if you decide to add EOS and BOS tokens to the input on the stage of + tokenization. Possible options are: {[None] + [e.value for e in SpecialTokenString]}. + special_token_to_append: special token to append to the output of ``text_to_ids`` of ``text_to_tokens`` + methods. See more in the description of ``special_token_to_prepend`` parameter. + special_tokens_to_remove_while_decoding: which special tokens are remove before detokenization. If this + parameter equals ``'all'``, then all special tokens are removed. The parameter + ``special_tokens_to_remove_while_decoding`` can also be a list of values from this set + {set(e.value for e in SpecialTokenString)}. + """ + + def __init__( + self, + vocab_file: str, + mask_token: Optional[Union[str, bool]] = None, + bos_token: Optional[Union[str, bool]] = None, + eos_token: Optional[Union[str, bool]] = None, + pad_token: Optional[Union[str, bool]] = None, + sep_token: Optional[Union[str, bool]] = None, + cls_token: Optional[Union[str, bool]] = None, + unk_token: Optional[Union[str, bool]] = None, + special_token_to_prepend: Optional[SpecialTokenStringType] = None, + special_token_to_append: Optional[SpecialTokenStringType] = None, + special_tokens_to_remove_while_decoding: Union[List[SpecialTokenStringType], str] = 'all', + ): + vocab_file = Path(vocab_file).expanduser() + with vocab_file.open(encoding='utf-8') as f: + first_line = f.readline() + if first_line[0] == '{': + special_tokens_dict = json.loads(first_line) + self.check_special_tokens_dict_from_file(special_tokens_dict, vocab_file) + vocab_list = f.readlines() + else: + special_tokens_dict = {} + vocab_list = [first_line] + f.readlines() + special_tokens_dict = self.update_special_tokens_dict( + special_tokens_dict, mask_token, bos_token, eos_token, pad_token, sep_token, cls_token, unk_token + ) + for e in SpecialTokenString: + name = e.value + '_token' + setattr(self, name, special_tokens_dict[name] if name in special_tokens_dict else None) + for k, v in special_tokens_dict.items(): + setattr(self, k, v) + for value, name in [ + (special_token_to_prepend, 'special_token_to_prepend'), + (special_token_to_append, 'special_token_to_append'), + ]: + self.check_special_token_name(name, value, special_tokens_dict) + setattr(self, name, value + '_token' if isinstance(value, str) else value) + self.vocab = {} + count = 0 + for v in special_tokens_dict.values(): + self.vocab[v] = count + count += 1 + for i, token in enumerate(vocab_list): + token = eval(token.strip()) + self.check_token_from_file(token, vocab_file, i) + if token not in self.vocab: + self.vocab[token] = count + count += 1 + self.inv_vocab = {v: k for k, v in self.vocab.items()} + self.vocab_size = len(self.vocab) + self.check_special_tokens_to_remove_while_decoding( + special_tokens_to_remove_while_decoding, special_tokens_dict + ) + self.special_token_ids_to_remove_while_decoding = ( + self.tokens_to_ids([v for v in special_tokens_dict.values()]) + if special_tokens_to_remove_while_decoding == 'all' + else [getattr(self, e + '_id') for e in special_tokens_to_remove_while_decoding] + ) + + @classmethod + def check_special_tokens_dict_from_file(cls, special_tokens_dict, vocab_file): + for k, v in special_tokens_dict.items(): + if k[-6:] != '_token' or not SpecialTokenString.has_value(k[:-6]): + raise ValueError( + f"Unsupported key {repr(k)} in special tokens dictionary in vocabulary file {vocab_file} " + f"(first line). Supported keys are {[e.value + '_token' for e in SpecialTokenString]}." + ) + if not isinstance(v, str): + raise ValueError( + f"Values of special tokens dictionary in vocabulary file {vocab_file} (first line) has to belong " + f"to type `str`, whereas type of item '{k}' value {repr(v)} is `{type(v)}`." + ) + elif len(v) == 0: + raise ValueError( + f"Values of special tokens dictionary in vocabulary file {vocab_file} (first line) has to not " + f"empty strings, whereas value of item '{k}' is an empty string." + ) + cls.check_special_tokens_dict_for_duplicate_values( + special_tokens_dict, f"Loaded from vocabulary file {vocab_file}" + ) + + @staticmethod + def check_special_tokens_dict_for_duplicate_values(special_tokens_dict, err_msg_prefix): + if len(special_tokens_dict) != len(set(special_tokens_dict.values())): + tokens_with_equal_values = [] + duplicate_values = [] + for k, v in list(reversed(list(special_tokens_dict.items())))[:-1]: + tokens = [k] + for kk, vv in special_tokens_dict.items(): + if kk == k: + break + if v == vv: + tokens.append(kk) + if len(tokens) > 1: + duplicate_values.append(v) + tokens_with_equal_values.append(tokens) + if duplicate_values: + dup_values_msg = '. '.join( + [f"Tokens {t} have value '{v}'" for t, v in zip(tokens_with_equal_values, duplicate_values)] + ) + raise ValueError( + err_msg_prefix + f" special tokens dictionary has duplicate values. " + dup_values_msg + ) + + @classmethod + def update_special_tokens_dict( + cls, + init_special_tokens_dict: Dict[str, str], + mask_token: Optional[Union[str, bool]] = None, + bos_token: Optional[Union[str, bool]] = None, + eos_token: Optional[Union[str, bool]] = None, + pad_token: Optional[Union[str, bool]] = None, + sep_token: Optional[Union[str, bool]] = None, + cls_token: Optional[Union[str, bool]] = None, + unk_token: Optional[Union[str, bool]] = None, + ): + special_tokens_dict = init_special_tokens_dict.copy() + for value, name in zip( + [pad_token, unk_token, bos_token, eos_token, sep_token, mask_token, cls_token], + ['pad_token', 'unk_token', 'bos_token', 'eos_token', 'sep_token', 'mask_token', 'cls_token'], + ): + if value is not None: + if isinstance(value, bool): + if value: + raise ValueError( + f"If `CharTokenizer` constructor parameter `{name}` is `bool` it has to be `False`" + ) + else: + if name in special_tokens_dict: + del special_tokens_dict[name] + else: + warnings.warn( + f"Cannot remove special token `{name}` since it is not in special tokens dictionary " + f"{special_tokens_dict}." + ) + elif not isinstance(value, str): + raise ValueError( + f"`CharTokenizer` constructor parameter `{name}` has to be either `False` or belong to type " + f"`str`, whereas type of `{name}` is `{type(value)}`." + ) + else: + special_tokens_dict[name] = value + cls.check_special_tokens_dict_for_duplicate_values( + special_tokens_dict, + "After updating special tokens dictionary with tokens passed in `CharTokenizer` constructor parameters", + ) + return special_tokens_dict + + @staticmethod + def check_token_from_file(token, vocab_file, line_i): + if not isinstance(token, str) or isinstance(token, str) and len(token) != 1: + raise ValueError( + f"Each line in vocabulary have to be a Python string literal containing 1 character. " + f"Encountered {repr(token)} on line {line_i} in file {vocab_file}." + ) + + @staticmethod + def check_special_token_name(parameter_name, value, special_tokens_dict): + if value is not None: + if not SpecialTokenString.has_value(value): + raise ValueError( + f"Value {repr(value)} of parameter `{parameter_name}` is wrong. Supported values are " + f"{[e.value for e in SpecialTokenString]}." + ) + elif value + '_token' not in special_tokens_dict: + raise ValueError( + f"You should provide `{value + '_token'}` parameter to `CharTokenizer` constructor if " + f"you wish to pass token {repr(value)} in parameter `{parameter_name}`." + ) + + @staticmethod + def check_special_tokens_to_remove_while_decoding(special_tokens_to_remove_while_decoding, special_tokens_dict): + if isinstance(special_tokens_to_remove_while_decoding, list): + for i, value in enumerate(special_tokens_to_remove_while_decoding): + if not SpecialTokenString.has_value(value): + raise ValueError( + f'Wrong element with value {repr(value)} in position {i} of parameter ' + f'`special_tokens_to_remove_while_decoding` of `CharTokenizer` constructor. Supported values ' + f'are {[e.value for e in SpecialTokenString]}.' + ) + elif value + '_token' not in special_tokens_dict: + raise ValueError( + f"You should provide `{value + '_token'}` parameter to `CharTokenizer` constructor if " + f"you wish to pass token {repr(value)} in parameter `special_tokens_to_remove_while_decoding`. " + f"`{value + '_token'}` was detected in position {i} in " + f"`special_tokens_to_remove_while_decoding`." + ) + elif ( + isinstance(special_tokens_to_remove_while_decoding, str) + and special_tokens_to_remove_while_decoding != 'all' + or not isinstance(special_tokens_to_remove_while_decoding, str) + ): + raise ValueError( + f"Parameter `special_tokens_to_remove_while_decoding` of `CharTokenizer` constructor has to be " + f"equal to a string 'all' or be a list of values from set {set(e.value for e in SpecialTokenString)} " + f"whereas `special_tokens_to_remove_while_decoding={repr(special_tokens_to_remove_while_decoding)}`" + ) + + def text_to_tokens(self, text: str) -> List[str]: + token_candidates = [char for char in text] + tokens = [] + if self.special_token_to_prepend is not None: + tokens.append(getattr(self, self.special_token_to_prepend)) + for i, token in enumerate(token_candidates): + if token in self.vocab: + tokens.append(token) + elif self.unk_token is not None: + tokens.append(self.unk_token) + else: + warnings.warn( + f"Character {repr(token)} in position {i} is not present in vocabulary and no `` token was " + f"set. Character {repr(token)} is discarded." + ) + if self.special_token_to_append is not None: + tokens.append(getattr(self, self.special_token_to_append)) + return tokens + + def tokens_to_text(self, tokens: List[str]) -> str: + return self.ids_to_text(self.tokens_to_ids(tokens)) + + def text_to_ids(self, text: str) -> List[int]: + ids = [self.vocab[token] for token in self.text_to_tokens(text)] + return ids + + def ids_to_text(self, ids: List[int]) -> str: + ids_ = [id_ for id_ in ids if id_ not in self.special_token_ids_to_remove_while_decoding] + return "".join(self.ids_to_tokens(ids_)) + + def tokens_to_ids(self, tokens: List[str]) -> List[int]: + return [self.vocab[token] for token in tokens] + + def token_to_id(self, token: str) -> int: + return self.vocab[token] + + def ids_to_tokens(self, ids: List[int]) -> List[str]: + return [self.inv_vocab[id] for id in ids] + + @staticmethod + def check_special_token_id_getting(special_token, id_name): + if special_token is None: + token_param = id_name[:-3] + '_token' + raise ValueError( + f"Cannot return `{id_name}` since `{token_param}` is not set. To obtain `{id_name}` you need to pass " + f"parameter `{token_param}` to `CharTokenizer` constructor." + ) + + @property + def pad_id(self): + self.check_special_token_id_getting(self.pad_token, 'pad_id') + return self.vocab[self.pad_token] + + @property + def bos_id(self): + self.check_special_token_id_getting(self.bos_token, 'bos_id') + return self.vocab[self.bos_token] + + @property + def eos_id(self): + self.check_special_token_id_getting(self.eos_token, 'eos_id') + return self.vocab[self.eos_token] + + @property + def unk_id(self): + self.check_special_token_id_getting(self.unk_token, 'unk_id') + return self.vocab[self.unk_token] + + @property + def mask_id(self): + self.check_special_token_id_getting(self.mask_token, 'mask_id') + return self.vocab[self.mask_token] + + @property + def sep_id(self): + self.check_special_token_id_getting(self.sep_token, 'sep_id') + return self.vocab[self.sep_token] + + @property + def cls_id(self): + self.check_special_token_id_getting(self.cls_token, 'cls_id') + return self.vocab[self.cls_token] + + @staticmethod + def create_special_tokens_dict( + mask_token: Optional[str] = None, + bos_token: Optional[str] = None, + eos_token: Optional[str] = None, + pad_token: Optional[str] = None, + sep_token: Optional[str] = None, + cls_token: Optional[str] = None, + unk_token: Optional[str] = None, + ): + special_tokens_dict = {} + for value, name in zip( + [pad_token, unk_token, bos_token, eos_token, sep_token, mask_token, cls_token], + ['pad_token', 'unk_token', 'bos_token', 'eos_token', 'sep_token', 'mask_token', 'cls_token'], + ): + if value is not None: + if not isinstance(value, str): + raise ValueError( + f"The type of parameter `{name}` has to be `None` or `str`, found `{type(value)}`" + ) + elif len(value) == 0: + raise ValueError(f"If the parameter `{name}` is `str`, then its length has to be nonzero.") + elif value in special_tokens_dict.values(): + other_name = None + for k, v in special_tokens_dict.items(): + if v == value: + other_name = k + raise ValueError( + f"The value {repr(value)} of special token `{name}` is the same as the value of special token " + f"`{other_name}`." + ) + special_tokens_dict[name] = value + return special_tokens_dict + + @staticmethod + def check_characters_to_exclude_from_vocabulary(characters_to_exclude_from_vocabulary): + for i, char in enumerate(characters_to_exclude_from_vocabulary): + if not isinstance(char, str): + raise ValueError( + f"Character to exclude from vocabulary has to `str`, whereas an element in position {i} is of " + f"type `{type(char)}`." + ) + elif len(char) != 1: + raise ValueError( + f"A length of an element of `characters_to_exclude_from_vocabulary` parameter has to be 1. " + f"The length of an element in position {i} is {len(char)}." + ) + + @staticmethod + def check_text_and_text_file_name(text, text_file_name): + if text is None and text_file_name is None: + raise ValueError( + f'Exactly one of parameters `text` and `text_file_name` should be provided whereas both parameters ' + f'are `None`.' + ) + if text is not None and text_file_name is not None: + raise ValueError( + f"Exactly one of parameters `text` and `text_file_name` has to be provided, whereas both parameters " + f"are not `None`." + ) + if text is not None: + if not isinstance(text, str): + raise ValueError( + f"Parameter `text` has to be of type `str`, whereas it belongs to type `{type(text)}`." + ) + + @classmethod + def build_vocab( + cls, + save_path: Union[str, bytes, os.PathLike], + text: Optional[str] = None, + text_file_name: Optional[Union[str, bytes, os.PathLike]] = None, + characters_to_exclude: Optional[List[str]] = None, + vocab_size: int = None, + mask_token: Optional[str] = None, + bos_token: Optional[str] = None, + eos_token: Optional[str] = None, + pad_token: Optional[str] = None, + sep_token: Optional[str] = None, + cls_token: Optional[str] = None, + unk_token: Optional[str] = None, + ): + """ + Creates character vocabulary and saves it to file ``save_path``. You should provide one of parameters ``text`` + and ``text_file_name``. The format of created character vocabulary file is following: + ``` + {['mask_token': "ANY NON EMPTY STRING", ]['bos_token': "ANY NON EMPTY STRING", ] and so on} + ' ' + 'e' + ... + ``` + The first line is a JSON which contains special tokens. This special token are set using parameters + ``mas_token``, ``bos_token``, ``eos_token``, ``pad_token``, ``sep_token``, ``cls_token``, ``unk_token``. + Other lines in created vocabulary file are Python string literals containing one character each. + + Args: + save_path: path to the output text file. If ``save_path`` parent directory does not exist it will be created + text: string which characters are used for vocabulary creation. + text_file_name: path to a file which characters are used for vocabulary creation. Use this parameter if + the text in file is too large to be loaded in memory. + characters_to_exclude: a list of characters which will not be added to vocabulary. + vocab_size: vocabulary size. If this parameter is set only most frequent ``vocab_size`` characters are added + to vocabulary. + mask_token: mask token + bos_token: the beginning of sequence token + eos_token: the end of sequence token. Usually equal to sep_token. + pad_token: token to use for padding. + sep_token: token used for separating sequences. + cls_token: class token. Usually equal to bos_token. + unk_token: token to use for unknown tokens. If the parameter ``unk_token`` is set and there is a character + in the input of ``text_to_ids`` of ``text_to_tokens`` methods which is not in the vocabulary, then + such an unknown character is tokenized into ``unk_token``. If the parameter ``unk_token`` is ``False``, + then unknown tokens are discarded. + """ + special_tokens_dict = cls.create_special_tokens_dict( + mask_token, bos_token, eos_token, pad_token, sep_token, cls_token, unk_token + ) + if characters_to_exclude is None: + characters_to_exclude = [] + else: + cls.check_characters_to_exclude_from_vocabulary(characters_to_exclude) + cls.check_text_and_text_file_name(text, text_file_name) + if text is not None: + counter = Counter(text) + else: + assert text_file_name is not None + text_file_name = Path(text_file_name).expanduser() + counter = Counter() + with text_file_name.open(encoding='utf-8') as f: + while True: + segment = f.read(NUMBER_OF_CHARACTERS_READ_BUFFER_SIZE) + if not segment: + break + counter.update(segment) + for char in characters_to_exclude: + if char in counter: + del counter[char] + save_path = Path(save_path).expanduser() + save_path.parent.mkdir(exist_ok=True, parents=True) + with save_path.open('w', encoding='utf-8') as f: + f.write(json.dumps(special_tokens_dict) + '\n') + if vocab_size is None: + for c, _ in sorted(counter.items(), key=lambda x: -x[1]): + f.write(repr(c) + '\n') + else: + vocab_size -= len(special_tokens_dict) + for i, (c, _) in enumerate(sorted(counter.items(), key=lambda x: -x[1])): + if i < vocab_size: + f.write(repr(c) + '\n') + else: + break diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/tokenizers/g2p_table_tokenizer.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/tokenizers/g2p_table_tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..aa886ca6b698303f208ada7caa57fa42894c210a --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/tokenizers/g2p_table_tokenizer.py @@ -0,0 +1,106 @@ +import random +from typing import List + +from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec + + +class G2PTableTokenizer(TokenizerSpec): + def __init__(self, vocab, g2p_table, unk_label=''): + self.vocab = vocab + self.vocab_size = len(self.vocab) + self.inv_vocab = {v: k for k, v in self.vocab.items()} + + self.unk_label = unk_label + self.unk_id = self.vocab[unk_label] + + self.g2p_table = g2p_table + + self.g2p_id_table = {} + for word, phone_seqs in self.g2p_table.items(): + phone_id_seqs = [] + for phone_seq_i in phone_seqs: + phone_seq_id_i = tuple(self.vocab[phone] for phone in phone_seq_i) + assert phone_seq_id_i not in phone_id_seqs + phone_id_seqs.append(phone_seq_id_i) + assert len(phone_id_seqs) > 0 + self.g2p_id_table[word] = phone_id_seqs + + @classmethod + def load(cls, vocab_fp, lexicon_fp): + vocab = load_vocab(vocab_fp) + + g2p_table = {} + with open(lexicon_fp, encoding='utf-8') as f: + for line in f: + line = line.strip() + word, phone_seq = line.split('\t') + if word not in g2p_table: + g2p_table[word] = [] + phone_seq = tuple(phone_seq.split(' ')) + if phone_seq in g2p_table[word]: + print('found duplicated mapping: ', line) + else: + assert len(phone_seq) > 0 + g2p_table[word].append(phone_seq) + return cls(vocab, g2p_table) + + def text_to_tokens(self, text: str) -> List[str]: + text = text.strip() + words = text.split() + tokens = [] + for word_i in words: + phones_list = self.g2p_table.get(word_i) + if phones_list: + if len(phones_list) == 1: + tokens.extend(phones_list[0]) + else: + tokens.extend(random.choice(phones_list)) + else: + tokens.append(self.unk_label) + return tokens + + def tokens_to_text(self, tokens: List[str]) -> str: + return ' '.join(tokens) + + def text_to_ids(self, text: str) -> List[int]: + words = text.split() + ids = [] + for word_i in words: + phone_ids_list = self.g2p_id_table.get(word_i) + if phone_ids_list: + if len(phone_ids_list) == 1: + ids.extend(phone_ids_list[0]) + else: + ids.extend(random.choice(phone_ids_list)) + else: + print('found unk: ', text) + ids.append(self.unk_id) + return ids + + def ids_to_text(self, ids: List[int]) -> str: + ids_ = [id_ for id_ in ids] + return ' '.join(self.ids_to_tokens(ids_)) + + def tokens_to_ids(self, tokens: List[str]) -> List[int]: + return [self.vocab[token] for token in tokens] + + def token_to_id(self, token: str) -> int: + return self.vocab[token] + + def ids_to_tokens(self, ids: List[int]) -> List[str]: + return [self.inv_vocab[id] for id in ids] + + +def load_vocab(path, encoding='utf-8'): + vocab_dict = {} + with open(path, encoding=encoding) as f: + idx = 0 + for line in f: + line = line.strip() + if not line: + continue + token = line.split('\t')[0] + assert token not in vocab_dict + vocab_dict[token] = idx + idx += 1 + return vocab_dict diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/tokenizers/sentencepiece_tokenizer.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/tokenizers/sentencepiece_tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..dbc120ba925c87717b1ce1f4fe52f259e7f52122 --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/tokenizers/sentencepiece_tokenizer.py @@ -0,0 +1,262 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import re +from typing import Dict, List, Optional, Union + +import sentencepiece + +from nemo.collections.common.parts.utils import if_exist +from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec +from nemo.utils import logging + +__all__ = ['SentencePieceTokenizer', 'create_spt_model'] + + +class SentencePieceTokenizer(TokenizerSpec): + ''' + Sentencepiecetokenizer https://github.com/google/sentencepiece. + ''' + + def __init__(self, model_path: str, special_tokens: Optional[Union[Dict[str, str], List[str]]] = None): + """ + Args: + model_path: path to sentence piece tokenizer model. To create the model use create_spt_model() + special_tokens: either list of special tokens or dictionary of token name to token value + """ + if not model_path or not os.path.exists(model_path): + raise ValueError(f"model_path: {model_path} is invalid") + self.tokenizer = sentencepiece.SentencePieceProcessor() + self.tokenizer.Load(model_path) + self.original_vocab_size = self.tokenizer.get_piece_size() + self.vocab_size = self.tokenizer.get_piece_size() + self.special_token_to_id = {} + self.id_to_special_token = {} + if special_tokens: + self.add_special_tokens(special_tokens) + + def text_to_tokens(self, text): + tokens = [] + idx = 0 + last_idx = 0 + + while 1: + indices = {} + + for token in self.special_token_to_id: + try: + indices[token] = text[idx:].index(token) + except ValueError: + continue + + if len(indices) == 0: + break + + next_token = min(indices, key=indices.get) + next_idx = idx + indices[next_token] + + tokens.extend(self.tokenizer.encode_as_pieces(text[idx:next_idx])) + tokens.append(next_token) + idx = next_idx + len(next_token) + + tokens.extend(self.tokenizer.encode_as_pieces(text[idx:])) + return tokens + + def text_to_ids(self, text, nbest_size=None, alpha=None): + ids = [] + idx = 0 + last_idx = 0 + + while 1: + indices = {} + + for token in self.special_token_to_id: + try: + indices[token] = text[idx:].index(token) + except ValueError: + continue + + if len(indices) == 0: + break + + next_token = min(indices, key=indices.get) + next_idx = idx + indices[next_token] + + if nbest_size is not None: + ids.extend(self.tokenizer.sample_encode_as_ids(text[idx:next_idx], nbest_size=nbest_size, alpha=alpha)) + else: + ids.extend(self.tokenizer.encode_as_ids(text[idx:next_idx])) + ids.append(self.special_token_to_id[next_token]) + idx = next_idx + len(next_token) + + if nbest_size is not None: + ids.extend(self.tokenizer.sample_encode_as_ids(text[idx:], nbest_size=nbest_size, alpha=alpha)) + else: + ids.extend(self.tokenizer.encode_as_ids(text[idx:])) + return ids + + def tokens_to_text(self, tokens): + return self.tokenizer.decode_pieces(tokens) + + def ids_to_text(self, ids): + text = "" + last_i = 0 + + for i, id in enumerate(ids): + if id in self.id_to_special_token: + text += self.tokenizer.decode_ids(ids[last_i:i]) + " " + text += self.id_to_special_token[id] + " " + last_i = i + 1 + + text += self.tokenizer.decode_ids(ids[last_i:]) + return text.strip() + + def token_to_id(self, token): + if token in self.special_token_to_id: + return self.special_token_to_id[token] + return self.tokenizer.piece_to_id(token) + + def ids_to_tokens(self, ids): + tokens = [] + for id in ids: + if id >= self.original_vocab_size: + tokens.append(self.id_to_special_token[id]) + else: + tokens.append(self.tokenizer.id_to_piece(id)) + return tokens + + def tokens_to_ids(self, tokens: Union[str, List[str]]) -> Union[int, List[int]]: + if isinstance(tokens, str): + tokens = [tokens] + ids = [] + for token in tokens: + ids.append(self.token_to_id(token)) + return ids + + def add_special_tokens(self, special_tokens): + if isinstance(special_tokens, list): + for token in special_tokens: + if ( + self.tokenizer.piece_to_id(token) == self.tokenizer.unk_id() + and token not in self.special_token_to_id + ): + self.special_token_to_id[token] = self.vocab_size + self.id_to_special_token[self.vocab_size] = token + self.vocab_size += 1 + elif isinstance(special_tokens, dict): + for token_name, token in special_tokens.items(): + setattr(self, token_name, token) + if ( + self.tokenizer.piece_to_id(token) == self.tokenizer.unk_id() + and token not in self.special_token_to_id + ): + self.special_token_to_id[token] = self.vocab_size + self.id_to_special_token[self.vocab_size] = token + self.vocab_size += 1 + + @property + def pad_id(self): + return self.tokens_to_ids([self.pad_token])[0] + + @property + def bos_id(self): + return self.tokens_to_ids([self.bos_token])[0] + + @property + def eos_id(self): + return self.tokens_to_ids([self.eos_token])[0] + + @property + def sep_id(self): + return self.tokens_to_ids([self.sep_token])[0] + + @property + def cls_id(self): + return self.tokens_to_ids([self.cls_token])[0] + + +def create_spt_model( + data_file: str, + vocab_size: int, + sample_size: int, + do_lower_case: bool, + tokenizer_type: str = 'unigram', + output_dir: Optional[str] = None, + character_coverage: float = 1.0, +): + """ + Creates sentence piece tokenizer model from data file. + Args: + data_file: data file + vocab_size: vocabulary size + sample_size: maximum size of sentences the trainer loads + do_lower_case: if text should be lower cased before tokenizer model is created + character_coverage: float value between 0 and 1 (as a percentage). For languages with a vast charset, + can be < 1.0, but for all other languages, it should be set as 1.0 + output_dir: folder to save created tokenizer model. If not specified will store model at data_file/../spt folder + """ + + if not data_file or not os.path.exists(data_file): + raise ValueError(f"data_file must be valid file path, but got {data_file}") + data_dir = os.path.dirname(data_file) + vocab = [] + if not output_dir: + output_dir = f'{data_dir}/spt' + if if_exist(output_dir, ['tokenizer.model']): + logging.info(f"tokenizer model {output_dir}/tokenizer.model already exists") + return f'{output_dir}/tokenizer.model', f'{output_dir}/vocab.txt' + logging.info(f'Processing {data_file} and store at {output_dir}') + os.makedirs(output_dir, exist_ok=True) + + cmd = ( + f"--input={data_file} --model_prefix={output_dir}/tokenizer " + f"--vocab_size={vocab_size} " + f"--shuffle_input_sentence=true --hard_vocab_limit=false " + f"--model_type={tokenizer_type} " + f"--character_coverage={character_coverage} " + f"--bos_id=-1 --eos_id=-1" + ) + if do_lower_case: + cmd += " --normalization_rule_name=nmt_nfkc_cf" + + if sample_size > 0: + cmd += f" --input_sentence_size={sample_size}" + + sentencepiece.SentencePieceTrainer.Train(cmd) + + # Add BERT control symbols + tokens = [] + + with open(f"{output_dir}/tokenizer.vocab", "r") as f: + f.readline() # skip first token + + # Read tokens from each line and parse for vocab + for line in f: + piece = line.split("\t")[0] + token = piece[1:] if piece.startswith("▁") else f"##{piece}" + + if len(token) > 0: + tokens.append(token) + else: + tokens.append(piece[0]) + + vocab.extend(tokens) + + # Save vocabulary to output file + vocab_file = f'{output_dir}/vocab.txt' + with open(vocab_file, "w") as f: + for token in vocab: + f.write(f"{token}\n") + return f'{output_dir}/tokenizer.model', vocab_file diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/tokenizers/tokenizer_spec.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/tokenizers/tokenizer_spec.py new file mode 100644 index 0000000000000000000000000000000000000000..252571d76ef205e8f29283c9143eccde11a3d6f8 --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/tokenizers/tokenizer_spec.py @@ -0,0 +1,55 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC, abstractmethod +from typing import List + +__all__ = ['TokenizerSpec'] + + +class TokenizerSpec(ABC): + """ + Inherit this class to implement a new tokenizer. + """ + + @abstractmethod + def text_to_tokens(self, text): + pass + + @abstractmethod + def tokens_to_text(self, tokens): + pass + + @abstractmethod + def tokens_to_ids(self, tokens): + pass + + @abstractmethod + def ids_to_tokens(self, ids): + pass + + @abstractmethod + def text_to_ids(self, text): + pass + + @abstractmethod + def ids_to_text(self, ids): + pass + + def add_special_tokens(self, special_tokens: List[str]): + raise NotImplementedError("To be implemented") + + @property + def name(self): + return type(self).__name__ diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/tokenizers/word_tokenizer.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/tokenizers/word_tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..f3431af9d734707f63c8d68f38a090cf9f76478b --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/tokenizers/word_tokenizer.py @@ -0,0 +1,72 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +from nemo.collections.common.tokenizers.char_tokenizer import CharTokenizer + +__all__ = ['WordTokenizer'] + + +class WordTokenizer(CharTokenizer): + "Tokenizes at word boundary" + + def __init__( + self, + vocab_file: str, + mask_token: Optional[str] = None, + bos_token: Optional[str] = None, + eos_token: Optional[str] = None, + pad_token: Optional[str] = None, + sep_token: Optional[str] = None, + cls_token: Optional[str] = None, + unk_token: Optional[str] = None, + ): + """ + Args: + vocab_file: path to file with vocabulary which consists + of characters separated by \n + mask_token: mask token + bos_token: the beginning of sequence token + eos_token: the end of sequence token. Usually equal to sep_token + pad_token: token to use for padding + sep_token: token used for separating sequences + cls_token: class token. Usually equal to bos_token + unk_token: token to use for unknown tokens + """ + + super().__init__( + vocab_file=vocab_file, + mask_token=mask_token, + bos_token=bos_token, + eos_token=eos_token, + pad_token=pad_token, + unk_token=unk_token, + sep_token=sep_token, + cls_token=cls_token, + ) + + def text_to_tokens(self, text): + token_candidates = text.strip().split() + tokens = [] + for token in token_candidates: + if token in self.vocab: + tokens.append(token) + else: + tokens.append(self.unk_token) + return tokens + + def ids_to_text(self, ids): + ids_ = [id_ for id_ in ids if id_ not in self.special_tokens] + return " ".join(self.ids_to_tokens(ids_)) diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/constants.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..f678ea13a2fbfc3c0efa68d180a72f1e9e87e44c --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/constants.py @@ -0,0 +1,18 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +NEMO_ENV_VARNAME_ENABLE_COLORING = "NEMO_ENABLE_COLORING" +NEMO_ENV_VARNAME_REDIRECT_LOGS_TO_STDERR = "NEMO_REDIRECT_LOGS_TO_STDERR" +NEMO_ENV_VARNAME_TESTING = "NEMO_TESTING" # Set to True to enable nemo.util.logging's debug mode +NEMO_ENV_VARNAME_VERSION = "NEMO_EXPM_VERSION" # Used for nemo.utils.exp_manager versioning diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/__init__.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..44d13c9efd936f2d6e602e486943829f5b3b613c --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import nemo.core.neural_types +from nemo.core.classes import * diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/classes/__init__.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/classes/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e9226c0ea84ff8c3d5926af6f753589ed94370b2 --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/classes/__init__.py @@ -0,0 +1,21 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from nemo.core.classes.common import FileIO, Model, Serialization, Typing, is_typecheck_enabled, typecheck +from nemo.core.classes.dataset import Dataset, IterableDataset +from nemo.core.classes.exportable import Exportable, ExportFormat +from nemo.core.classes.loss import Loss +from nemo.core.classes.modelPT import ModelPT +from nemo.core.classes.module import NeuralModule diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/classes/common.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/classes/common.py new file mode 100644 index 0000000000000000000000000000000000000000..b0aa82481ccb739a3556f51eb07eea37985b4d94 --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/classes/common.py @@ -0,0 +1,553 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""Interfaces common to all Neural Modules and Models.""" +import hashlib +from abc import ABC, abstractmethod +from contextlib import contextmanager +from dataclasses import dataclass +from enum import Enum +from pathlib import Path +from typing import Dict, List, Optional, Union + +import hydra +import wrapt +from omegaconf import DictConfig, OmegaConf + +import nemo +from nemo.core.neural_types import NeuralType, NeuralTypeComparisonResult +from nemo.utils import logging +from nemo.utils.cloud import maybe_download_from_cloud +from nemo.utils.model_utils import maybe_update_config_version + +__all__ = ['Typing', 'FileIO', 'Model', 'Serialization', 'typecheck'] + + +_TYPECHECK_ENABLED = True + + +def is_typecheck_enabled(): + """ + Getter method for typechecking state. + """ + return _TYPECHECK_ENABLED + + +class Typing(ABC): + """ + An interface which endows module with neural types + """ + + @property + def input_types(self) -> Optional[Dict[str, NeuralType]]: + """Define these to enable input neural type checks""" + return None + + @property + def output_types(self) -> Optional[Dict[str, NeuralType]]: + """Define these to enable output neural type checks""" + return None + + def _validate_input_types(self, input_types=None, **kwargs): + """ + This function does a few things. + 1) It ensures that len(self.input_types ) <= len(kwargs) <= len(self.input_types). + 2) For each (keyword name, keyword value) passed as input to the wrapped function: + - Check if the keyword name exists in the list of valid self.input_types names. + - Check if keyword value has the `neural_type` property. + - If it does, then perform a comparative check and assert that neural types + are compatible (SAME or GREATER). + - Check if keyword value is a container type (list or tuple). If yes, + then perform the elementwise test of neural type above on each element + of the nested structure, recursively. + + Args: + input_types: Either the `input_types` defined at class level, or the local function + overridden type definition. + kwargs: Dictionary of argument_name:argument_value pairs passed to the wrapped + function upon call. + """ + # TODO: Properly implement this + if input_types is not None: + total_input_types = len(input_types) + mandatory_input_types = len( + [type_val for type_key, type_val in input_types.items() if not type_val.optional] + ) + + if len(kwargs) < mandatory_input_types or len(kwargs) > total_input_types: + raise TypeError( + f"Number of input arguments provided ({len(kwargs)}) is not as expected. Function has " + f"{total_input_types} total inputs with {mandatory_input_types} mandatory inputs." + ) + + for key, value in kwargs.items(): + # Check if keys exists in the defined input types + if key not in input_types: + raise TypeError( + f"Input argument {key} has no corresponding input_type match. " + f"Existing input_types = {input_types.keys()}" + ) + + # Perform neural type check + if hasattr(value, 'neural_type') and not input_types[key].compare(value.neural_type) in ( + NeuralTypeComparisonResult.SAME, + NeuralTypeComparisonResult.GREATER, + ): + error_msg = [ + f"{input_types[key].compare(value.neural_type)} :", + f"Input type expected : {input_types[key]}", + f"Input type found : {value.neural_type}", + ] + for i, dict_tuple in enumerate(input_types[key].elements_type.type_parameters.items()): + error_msg.insert(i + 2, f' input param_{i} : {dict_tuple[0]}: {dict_tuple[1]}') + for i, dict_tuple in enumerate(value.neural_type.elements_type.type_parameters.items()): + error_msg.append(f' input param_{i} : {dict_tuple[0]}: {dict_tuple[1]}') + raise TypeError("\n".join(error_msg)) + + # Perform input ndim check + if hasattr(value, 'shape'): + value_shape = value.shape + type_shape = input_types[key].axes + name = key + + if type_shape is not None and len(value_shape) != len(type_shape): + raise TypeError( + f"Input shape mismatch occured for {name} in module {self.__class__.__name__} : \n" + f"Input shape expected = {input_types[key].axes} | \n" + f"Input shape found : {value_shape}" + ) + + # Perform recursive neural type check for homogeneous elements + elif isinstance(value, list) or isinstance(value, tuple): + for ind, val in enumerate(value): + self.__check_neural_type(val, input_types[key], name=key) + + def _attach_and_validate_output_types(self, out_objects, output_types=None): + """ + This function does a few things. + 1) It ensures that len(out_object) == len(self.output_types). + 2) If the output is a tensor (or list/tuple of list/tuple ... of tensors), it + attaches a neural_type to it. For objects without the neural_type attribute, + such as python objects (dictionaries and lists, primitive data types, structs), + no neural_type is attached. + + Note: tensor.neural_type is only checked during _validate_input_types which is + called prior to forward(). + + Args: + output_types: Either the `output_types` defined at class level, or the local function + overridden type definition. + out_objects: The outputs of the wrapped function. + """ + # TODO: Properly implement this + if output_types is not None: + out_types_list = list(output_types.items()) + + # First convert all outputs to list/tuple format to check correct number of outputs + if type(out_objects) in (list, tuple): + out_container = out_objects + else: + out_container = [out_objects] + + if len(output_types) != len(out_container): + raise TypeError( + "Number of output arguments provided ({}) is not as expected ({})".format( + len(out_container), len(output_types) + ) + ) + + # Attach types recursively, if possible + if not isinstance(out_objects, tuple) and not isinstance(out_objects, list): + try: + out_objects.neural_type = out_types_list[0][1] + except Exception: + pass + + # Perform output ndim check + if hasattr(out_objects, 'shape'): + value_shape = out_objects.shape + type_shape = out_types_list[0][1].axes + name = out_types_list[0][0] + + if type_shape is not None and len(value_shape) != len(type_shape): + raise TypeError( + f"Output shape mismatch occured for {name} in module {self.__class__.__name__} : \n" + f"Output shape expected = {type_shape} | \n" + f"Output shape found : {value_shape}" + ) + else: + for ind, res in enumerate(out_objects): + self.__attach_neural_type(res, out_types_list[ind][1], name=out_types_list[ind][0]) + + def __check_neural_type(self, obj, type_val, name=None): + if isinstance(obj, tuple) or isinstance(obj, list): + for elem in obj: + self.__check_neural_type(elem, type_val, name=name) + return # after processing nest, return to avoid testing nest itself + + if hasattr(obj, 'neural_type') and not type_val.compare(obj.neural_type) in ( + NeuralTypeComparisonResult.SAME, + NeuralTypeComparisonResult.GREATER, + ): + raise TypeError( + f"{type_val.compare(obj.neural_type)} : \n" + f"Input type expected = {type_val} | \n" + f"Input type found : {obj.neural_type}" + ) + + # Perform input ndim check + if hasattr(obj, 'shape'): + value_shape = obj.shape + type_shape = type_val.axes + + if type_shape is not None and len(value_shape) != len(type_shape): + raise TypeError( + f"Input shape mismatch occured for {name} in module {self.__class__.__name__} : \n" + f"Input shape expected = {type_shape} | \n" + f"Input shape found : {value_shape}" + ) + + def __attach_neural_type(self, obj, type_val, name=None): + if isinstance(obj, tuple) or isinstance(obj, list): + for elem in obj: + self.__attach_neural_type(elem, type_val, name=name) + return # after processing nest, return to avoid argument insertion into nest itself + + try: + obj.neural_type = type_val + except Exception: + pass + + # Perform output ndim check + if hasattr(obj, 'shape'): + value_shape = obj.shape + type_shape = type_val.axes + + if type_shape is not None and len(value_shape) != len(type_shape): + raise TypeError( + f"Output shape mismatch occured for {name} in module {self.__class__.__name__} : \n" + f"Output shape expected = {type_shape} | \n" + f"Output shape found : {value_shape}" + ) + + +class Serialization(ABC): + @classmethod + def from_config_dict(cls, config: DictConfig): + """Instantiates object using DictConfig-based configuration""" + # Resolve the config dict + if isinstance(config, DictConfig): + config = OmegaConf.to_container(config, resolve=True) + config = OmegaConf.create(config) + OmegaConf.set_struct(config, True) + + config = maybe_update_config_version(config) + + if ('cls' in config or 'target' in config) and 'params' in config: + # regular hydra-based instantiation + instance = hydra.utils.instantiate(config=config) + elif '_target_' in config: + # regular hydra-based instantiation + instance = hydra.utils.instantiate(config=config) + else: + # models are handled differently for now + instance = cls(cfg=config) + + if not hasattr(instance, '_cfg'): + instance._cfg = config + return instance + + def to_config_dict(self) -> DictConfig: + """Returns object's configuration to config dictionary""" + if hasattr(self, '_cfg') and self._cfg is not None and isinstance(self._cfg, DictConfig): + # Resolve the config dict + config = OmegaConf.to_container(self._cfg, resolve=True) + config = OmegaConf.create(config) + OmegaConf.set_struct(config, True) + + config = maybe_update_config_version(config) + + self._cfg = config + + return self._cfg + else: + raise NotImplementedError( + 'to_config_dict() can currently only return object._cfg but current object does not have it.' + ) + + +class FileIO(ABC): + def save_to(self, save_path: str): + """Saves module/model with weights""" + raise NotImplementedError() + + @classmethod + def restore_from( + cls, + restore_path: str, + override_config_path: Optional[str] = None, + map_location: Optional['torch.device'] = None, + strict: bool = True, + ): + """Restores module/model with weights""" + raise NotImplementedError() + + @classmethod + def from_config_file(cls, path2yaml_file: str): + """ + Instantiates an instance of NeMo Model from YAML config file. + Weights will be initialized randomly. + Args: + path2yaml_file: path to yaml file with model configuration + + Returns: + + """ + if issubclass(cls, Serialization): + conf = OmegaConf.load(path2yaml_file) + return cls.from_config_dict(config=conf) + else: + raise NotImplementedError() + + def to_config_file(self, path2yaml_file: str): + """ + Saves current instance's configuration to YAML config file. Weights will not be saved. + Args: + path2yaml_file: path2yaml_file: path to yaml file where model model configuration will be saved + + Returns: + """ + if hasattr(self, '_cfg'): + self._cfg = maybe_update_config_version(self._cfg) + with open(path2yaml_file, 'w', encoding='utf-8') as fout: + OmegaConf.save(config=self._cfg, f=fout, resolve=True) + else: + raise NotImplementedError() + + +@dataclass +class PretrainedModelInfo: + pretrained_model_name: str + description: str + location: str + class_: 'Model' = None + + +class Model(Typing, Serialization, FileIO): + """ + Abstract class offering interface which should be implemented by all NeMo models. + """ + + @classmethod + @abstractmethod + def list_available_models(cls) -> Optional[PretrainedModelInfo]: + """ + Should list all pre-trained models available via NVIDIA NGC cloud + + Returns: + A list of PretrainedModelInfo entries + """ + pass + + @classmethod + def get_available_model_names(cls) -> List[str]: + """ + Returns the list of model names available via NVIDIA NGC cloud, + to get the complete model description use list_available_models() + Returns: + A list of model names + """ + model_names = [] + if cls.list_available_models() is not None: + model_names = [model.pretrained_model_name for model in cls.list_available_models()] + return model_names + + @classmethod + def from_pretrained( + cls, + model_name: str, + refresh_cache: bool = False, + override_config_path: Optional[str] = None, + map_location: Optional['torch.device'] = None, + strict: bool = True, + ): + """ + Instantiates an instance of NeMo from NVIDIA NGC cloud + Use restore_from() to instantiate from a local .nemo file. + Args: + model_name: string key which will be used to find the module. + refresh_cache: If set to True, then when fetching from cloud, this will re-fetch the file + from cloud even if it is already found in a cache locally. + override_config_path: path to a yaml config that will override the internal + config file + map_location: Optional torch.device() to map the instantiated model to a device. + By default (None), it will select a GPU if available, falling back to CPU otherwise. + strict: Passed to torch.load_state_dict + + Returns: + A model instance of a particular model class + """ + location_in_the_cloud = None + description = None + if cls.list_available_models() is not None: + for pretrained_model_info in cls.list_available_models(): + if pretrained_model_info.pretrained_model_name == model_name: + location_in_the_cloud = pretrained_model_info.location + description = pretrained_model_info.description + class_ = pretrained_model_info.class_ + if location_in_the_cloud is None: + raise FileNotFoundError( + f"Model {model_name} was not found. Check cls.list_available_models() for the list of all available models." + ) + filename = location_in_the_cloud.split("/")[-1] + url = location_in_the_cloud.replace(filename, "") + cache_dir = Path.joinpath(Path.home(), f'.cache/torch/NeMo/NeMo_{nemo.__version__}/{filename[:-5]}') + # If either description and location in the cloud changes, this will force re-download + cache_subfolder = hashlib.md5((location_in_the_cloud + description).encode('utf-8')).hexdigest() + # if file exists on cache_folder/subfolder, it will be re-used, unless refresh_cache is True + nemo_model_file_in_cache = maybe_download_from_cloud( + url=url, filename=filename, cache_dir=cache_dir, subfolder=cache_subfolder, refresh_cache=refresh_cache + ) + logging.info("Instantiating model from pre-trained checkpoint") + if class_ is None: + class_ = cls + instance = class_.restore_from( + restore_path=nemo_model_file_in_cache, + override_config_path=override_config_path, + map_location=map_location, + strict=strict, + ) + return instance + + +class typecheck: + class TypeState(Enum): + """ + Placeholder to denote the default value of type information provided. + If the constructor of this decorator is used to override the class level type definition, + this enum value indicate that types will be overridden. + """ + + UNINITIALIZED = 0 + + def __init__( + self, + input_types: Union[TypeState, Dict[str, NeuralType]] = TypeState.UNINITIALIZED, + output_types: Union[TypeState, Dict[str, NeuralType]] = TypeState.UNINITIALIZED, + ): + """ + A decorator which performs input-output neural type checks, and attaches + neural types to the output of the function that it wraps. + + Requires that the class inherit from `nemo.core.Typing` in order to perform + type checking, and will raise an error if that is not the case. + + # Usage (Class level type support) + @typecheck() + def fn(self, arg1, arg2, ...): + ... + + # Usage (Function level type support) + @typecheck(input_types=..., output_types=...) + def fn(self, arg1, arg2, ...): + ... + + Points to be noted: + 1) The brackets () in `@typecheck()` are necessary. + + You will encounter a TypeError: __init__() takes 1 positional argument but X + were given without those brackets. + + 2) The function can take any number of positional arguments during definition. + + When you call this function, all arguments must be passed using kwargs only. + + """ + self.input_types = input_types + self.output_types = output_types + + if input_types == self.TypeState.UNINITIALIZED: + self.input_override = False + else: + self.input_override = True + + if output_types == self.TypeState.UNINITIALIZED: + self.output_override = False + else: + self.output_override = True + + @wrapt.decorator(enabled=is_typecheck_enabled) + def __call__(self, wrapped, instance: Typing, args, kwargs): + if instance is None: + raise RuntimeError("Only classes which inherit nemo.core.Typing can use this decorator !") + + if not isinstance(instance, Typing): + raise RuntimeError("Only classes which inherit nemo.core.Typing can use this decorator !") + + if hasattr(instance, 'input_ports') or hasattr(instance, 'output_ports'): + raise RuntimeError( + "Typing requires override of `input_types()` and `output_types()`, " + "not `input_ports() and `output_ports()`" + ) + + # Preserve type information + if self.input_types is typecheck.TypeState.UNINITIALIZED: + self.input_types = instance.input_types + + if self.output_types is typecheck.TypeState.UNINITIALIZED: + self.output_types = instance.output_types + + # Resolve global type or local overridden type + if self.input_override: + input_types = self.input_types + else: + input_types = instance.input_types + + if self.output_override: + output_types = self.output_types + else: + output_types = instance.output_types + + # If types are not defined, skip type checks and just call the wrapped method + if input_types is None and output_types is None: + return wrapped(*args, **kwargs) + + # Check that all arguments are kwargs + if input_types is not None and len(args) > 0: + raise TypeError("All arguments must be passed by kwargs only for typed methods") + + # Perform rudimentary input checks here + instance._validate_input_types(input_types=input_types, **kwargs) + + # Call the method - this can be forward, or any other callable method + outputs = wrapped(*args, **kwargs) + + instance._attach_and_validate_output_types(output_types=output_types, out_objects=outputs) + + return outputs + + @staticmethod + def set_typecheck_enabled(enabled: bool = True): + global _TYPECHECK_ENABLED + _TYPECHECK_ENABLED = enabled + + @staticmethod + @contextmanager + def disable_checks(): + typecheck.set_typecheck_enabled(enabled=False) + try: + yield + finally: + typecheck.set_typecheck_enabled(enabled=True) diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/classes/dataset.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/classes/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..51bd46ef79010660616eba88461b1c74ce8fbd04 --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/classes/dataset.py @@ -0,0 +1,109 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Optional + +from torch.utils import data + +from nemo.core.classes import Serialization, Typing, typecheck + +__all__ = ['Dataset', 'IterableDataset'] + + +class Dataset(data.Dataset, Typing, Serialization): + """Dataset with output ports + + Please Note: Subclasses of IterableDataset should *not* implement input_types. + """ + + def _collate_fn(self, batch): + """ + A default implementation of a collation function. + Users should override this method to define custom data loaders. + """ + return data.dataloader.default_collate(batch) + + @typecheck() + def collate_fn(self, batch): + """ + This is the method that user pass as functor to DataLoader. + The method optionally performs neural type checking and add types to the outputs. + + Please note, subclasses of Dataset should not implement `input_types`. + + # Usage: + dataloader = torch.utils.data.DataLoader( + ...., + collate_fn=dataset.collate_fn, + .... + ) + + Returns: + Collated batch, with or without types. + """ + if self.input_types is not None: + raise TypeError("Datasets should not implement `input_types` as they are not checked") + + # Simply forward the inner `_collate_fn` + return self._collate_fn(batch) + + +class IterableDataset(data.IterableDataset, Typing, Serialization): + """Iterable Dataset with output ports + + Please Note: Subclasses of IterableDataset should *not* implement input_types. + """ + + def _collate_fn(self, batch): + """ + A default implementation of a collation function. + Users should override this method to define custom data loaders. + """ + return data.dataloader.default_collate(batch) + + @typecheck() + def collate_fn(self, batch): + """ + This is the method that user pass as functor to DataLoader. + The method optionally performs neural type checking and add types to the outputs. + + # Usage: + dataloader = torch.utils.data.DataLoader( + ...., + collate_fn=dataset.collate_fn, + .... + ) + + Returns: + Collated batch, with or without types. + """ + if self.input_types is not None: + raise TypeError("Datasets should not implement `input_types` as they are not checked") + + # Simply forward the inner `_collate_fn` + return self._collate_fn(batch) + + +@dataclass +class DatasetConfig: + """ + + """ + + # ... + batch_size: int = 32 + drop_last: bool = False + shuffle: bool = False + num_workers: Optional[int] = None + pin_memory: bool = True diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/classes/exportable.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/classes/exportable.py new file mode 100644 index 0000000000000000000000000000000000000000..fddc2f4df5f8da51b54c3ada2618c10592b865ba --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/classes/exportable.py @@ -0,0 +1,212 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from abc import ABC +from collections import defaultdict +from enum import Enum +from typing import Dict + +import onnx +import torch + +from nemo.core.classes import typecheck +from nemo.core.neural_types import AxisKind, NeuralType +from nemo.utils.export_utils import replace_for_export + +__all__ = ['ExportFormat', 'Exportable'] + + +class ExportFormat(Enum): + """Which format to use when exporting a Neural Module for deployment""" + + ONNX = (1,) + TORCHSCRIPT = (2,) + + +_EXT_DICT = { + ".pt": ExportFormat.TORCHSCRIPT, + ".onnx": ExportFormat.ONNX, +} + + +class Exportable(ABC): + """ + This Interface should be implemented by particular classes derived from nemo.core.NeuralModule or nemo.core.ModelPT. + It gives these entities ability to be exported for deployment to formats such as ONNX. + """ + + @staticmethod + def get_format(filename: str): + _, ext = os.path.splitext(filename) + try: + return _EXT_DICT[ext] + except KeyError: + raise ValueError(f"Export file {filename} extension does not correspond to any export format!") + + def export( + self, + output: str, + input_example=None, + output_example=None, + verbose=False, + export_params=True, + do_constant_folding=True, + keep_initializers_as_inputs=False, + onnx_opset_version: int = 12, + try_script: bool = False, + set_eval: bool = True, + check_trace: bool = True, + use_dynamic_axes: bool = True, + ): + try: + # Disable typechecks + typecheck.set_typecheck_enabled(enabled=False) + + # Set module to eval mode + if set_eval: + self.eval() + + format = self.get_format(output) + self._prepare_for_export() + + if input_example is not None: + _in_example = input_example + else: + _in_example = self.input_example() + + if output_example is None: + _out_example = self.forward(*_in_example) + + if not (hasattr(self, 'input_types') and hasattr(self, 'output_types')): + raise NotImplementedError('For export to work you must define input and output types') + input_names = list(self.input_types.keys()) + output_names = list(self.output_types.keys()) + # dynamic axis is a mapping from input/output_name => list of "dynamic" indices + dynamic_axes = defaultdict(list) + + # extract dynamic axes and remove unnecessary inputs/outputs + # for input_ports + for _name, ntype in self.input_types.items(): + if _name in self.disabled_deployment_input_names: + input_names.remove(_name) + continue + if use_dynamic_axes: + dynamic_axes = {**dynamic_axes, **self._extract_dynamic_axes(_name, ntype)} + # for output_ports + for _name, ntype in self.output_types.items(): + if _name in self.disabled_deployment_output_names: + output_names.remove(_name) + continue + if use_dynamic_axes: + dynamic_axes = {**dynamic_axes, **self._extract_dynamic_axes(_name, ntype)} + + if len(dynamic_axes) == 0: + dynamic_axes = None + + with torch.jit.optimized_execution(True): + jitted_model = None + if try_script: + try: + jitted_model = torch.jit.script(self) + except Exception as e: + print("jit.script() failed!", e) + if _in_example is None: + raise ValueError(f'Example input is None, but jit.script() has failed or not tried') + + if isinstance(_in_example, Dict): + _in_example = tuple(_in_example.values()) + + if jitted_model is None: + jitted_model = torch.jit.trace(self, _in_example, check_trace=check_trace) + + if format == ExportFormat.TORCHSCRIPT: + jitted_model.save(output) + assert os.path.exists(output) + elif format == ExportFormat.ONNX: + if _out_example is None: + if isinstance(_in_example, tuple): + _out_example = self.forward(*_in_example) + else: + _out_example = self.forward(_in_example) + + torch.onnx.export( + jitted_model, + _in_example, + output, + input_names=input_names, + output_names=output_names, + verbose=verbose, + export_params=export_params, + do_constant_folding=do_constant_folding, + keep_initializers_as_inputs=keep_initializers_as_inputs, + dynamic_axes=dynamic_axes, + opset_version=onnx_opset_version, + example_outputs=_out_example, + ) + + # Verify the model can be read, and is valid + onnx_model = onnx.load(output) + onnx.checker.check_model(onnx_model, full_check=True) + return onnx_model + else: + raise ValueError(f'Encountered unknown export format {format}.') + finally: + typecheck.set_typecheck_enabled(enabled=True) + return [output] # Subclasses may create more than one file. + + @property + def disabled_deployment_input_names(self): + """Implement this method to return a set of input names disabled for export""" + return set() + + @property + def disabled_deployment_output_names(self): + """Implement this method to return a set of output names disabled for export""" + return set() + + @property + def supported_export_formats(self): + """Implement this method to return a set of export formats supported. Default is all types.""" + return set([ExportFormat.ONNX, ExportFormat.TORCHSCRIPT]) + + @staticmethod + def _extract_dynamic_axes(name: str, ntype: NeuralType): + """ + Implement this method to provide dynamic axes id for ONNX export. + By default, this method will extract BATCH and TIME dimension ids from each provided input/output name argument. + + For example, if module/model accepts argument named "input_signal" with type corresponding to [Batch, Time, Dim] + shape, then the returned result should contain "input_signal" -> [0, 1] because Batch and Time are dynamic axes + as they can change from call to call during inference. + + Args: + name: Name of input or output parameter + ntype: Corresponding Neural Type + + Returns: + + """ + dynamic_axes = defaultdict(list) + if ntype.axes: + for ind, axis in enumerate(ntype.axes): + if axis.kind in [AxisKind.Batch, AxisKind.Time, AxisKind.Width, AxisKind.Height]: + dynamic_axes[name].append(ind) + return dynamic_axes + + def _prepare_for_export(self): + """ + Override this method to prepare module for export. This is in-place operation. + Base version does common necessary module replacements (Apex etc) + """ + replace_for_export(self) diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/classes/loss.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/classes/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..d6ede4049c5374548cecc389cd8529e8db576dfd --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/classes/loss.py @@ -0,0 +1,26 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from nemo.core.classes.common import Serialization, Typing + +__all__ = ['Loss'] + + +class Loss(torch.nn.modules.loss._Loss, Typing, Serialization): + """Inherit this class to implement custom loss.""" + + def __init__(self, **kwargs): + super(Loss, self).__init__(**kwargs) diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/classes/modelPT.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/classes/modelPT.py new file mode 100644 index 0000000000000000000000000000000000000000..beb75d4341337a0e9ad55bccacbf51fae94754bf --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/classes/modelPT.py @@ -0,0 +1,1303 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import inspect +import os +import shutil +import tarfile +import tempfile +from abc import abstractmethod +from os import path +from typing import Callable, Dict, List, Optional, Union + +import hydra +import torch +from omegaconf import DictConfig, OmegaConf, open_dict +from pytorch_lightning import LightningModule, Trainer +from pytorch_lightning.utilities import rank_zero_only + +from nemo.core import optim +from nemo.core.classes.common import Model +from nemo.core.config.modelPT import ModelPTConfig +from nemo.core.optim import prepare_lr_scheduler +from nemo.utils import config_utils, logging, model_utils +from nemo.utils.app_state import AppState +from nemo.utils.get_rank import is_global_rank_zero + +# Need to set them before EFF import as it is using them. +_MODEL_CONFIG_YAML = "model_config.yaml" +_MODEL_WEIGHTS = "model_weights.ckpt" + +try: + # Try to import strategies for .nemo archive. + from eff.cookbooks import NeMoCookbook + + _EFF_PRESENT_ = True +except ImportError: + _EFF_PRESENT_ = False + +__all__ = ['ModelPT'] + +""" +Internal global flags that determine core functionality of ModelPT. + +_MODEL_IS_RESTORED: + This flag determines the context of the model - whether the model is currently being + restored or not. + - When set, it can be assumed that the model's will disable all automatic methods - + setup_training_data(), setup_validation/test_data() and their multi equivalents. + - If a model is being restored from a archive file (tarfile), it can be assumed that + under this context, the cwd is *inside* the tarfile itself. + +_MODEL_RESTORE_PATH: + A string path to a a file from which the model is being restored. + This file can either be a PyTorch Lightning Checkpoint, or a archive (tarfile) that contains + artifact objects. + If it is an archive file, during restoration, the cwd will be temporarily moved to inside the + archive itself. + +_MODEL_EFF_SAVE: + A global flag that switches the format of the archive file that will be stored. + This flag only enables EFF when the package support is available. +""" +_MODEL_IS_RESTORED = False +_MODEL_RESTORE_PATH = None +_MODEL_EFF_SAVE = True + + +class ModelPT(LightningModule, Model): + """ + Interface for Pytorch-lightning based NeMo models + """ + + def __init__(self, cfg: DictConfig, trainer: Trainer = None): + """ + Base class from which all NeMo models should inherit + + Args: + cfg (DictConfig): configuration object. + The cfg object should have (optionally) the following sub-configs: + + * train_ds - to instantiate training dataset + * validation_ds - to instantiate validation dataset + * test_ds - to instantiate testing dataset + * optim - to instantiate optimizer with learning rate scheduler + + trainer (Optional): Pytorch Lightning Trainer instance + """ + if trainer is not None and not isinstance(trainer, Trainer): + raise ValueError( + f"trainer constructor argument must be either None or pytroch_lightning.Trainer. But got {type(trainer)} instead." + ) + super().__init__() + + # Convert config to a DictConfig + cfg = model_utils.convert_model_config_to_dict_config(cfg) + + # Convert config to support Hydra 1.0+ instantiation + cfg = model_utils.maybe_update_config_version(cfg) + + if 'target' not in cfg: + # This is for Jarvis service. + OmegaConf.set_struct(cfg, False) + cfg.target = "{0}.{1}".format(self.__class__.__module__, self.__class__.__name__) + OmegaConf.set_struct(cfg, True) + + self._cfg = cfg + + self.save_hyperparameters(self._cfg) + self._train_dl = None + self._validation_dl = None + self._test_dl = None + self._optimizer = None + self._scheduler = None + self._trainer = trainer + + # Set device_id in AppState + if torch.cuda.is_available() and torch.cuda.current_device() is not None: + app_state = AppState() + app_state.device_id = torch.cuda.current_device() + + if self._cfg is not None and not self._is_model_being_restored(): + if 'train_ds' in self._cfg and self._cfg.train_ds is not None: + self.setup_training_data(self._cfg.train_ds) + + if 'validation_ds' in self._cfg and self._cfg.validation_ds is not None: + self.setup_multiple_validation_data(val_data_config=None) + + if 'test_ds' in self._cfg and self._cfg.test_ds is not None: + self.setup_multiple_test_data(test_data_config=None) + + else: + if 'train_ds' in self._cfg and self._cfg.train_ds is not None: + logging.warning( + f"Please call the ModelPT.setup_training_data() method " + f"and provide a valid configuration file to setup the train data loader.\n" + f"Train config : \n{OmegaConf.to_yaml(self._cfg.train_ds)}" + ) + + if 'validation_ds' in self._cfg and self._cfg.validation_ds is not None: + logging.warning( + f"Please call the ModelPT.setup_validation_data() or ModelPT.setup_multiple_validation_data() method " + f"and provide a valid configuration file to setup the validation data loader(s). \n" + f"Validation config : \n{OmegaConf.to_yaml(self._cfg.validation_ds)}" + ) + + if 'test_ds' in self._cfg and self._cfg.test_ds is not None: + logging.warning( + f"Please call the ModelPT.setup_test_data() or ModelPT.setup_multiple_test_data() method " + f"and provide a valid configuration file to setup the test data loader(s).\n" + f"Test config : \n{OmegaConf.to_yaml(self._cfg.test_ds)}" + ) + + # ModelPT wrappers over subclass implementations + self.training_step = model_utils.wrap_training_step(self.training_step) + + def register_artifact(self, config_path: str, src: str): + """ + Register model artifacts with this function. These artifacts (files) will be included inside .nemo file + when model.save_to("mymodel.nemo") is called. + + WARNING: If you specified /example_folder/example.txt but ./example.txt exists, then ./example.txt will be used. + + Args: + config_path: config path where artifact is used + src: path to the artifact + + Returns: + path to be used when accessing artifact. If src='' or None then '' or None will be returned + """ + if not hasattr(self, 'artifacts'): + self.artifacts = {} + if self.artifacts is None: + self.artifacts = {} + if src is not None and src.strip() != '': + archive_item = model_utils.ArtifactItem() + + basename_src = os.path.basename(src) + # filename exists in current workdir - use it and raise warning + # this case is during model restoration or when file is written to cwd. + if os.path.exists(basename_src): + logging.warning(f"Using {os.path.abspath(basename_src)} instead of {src}.") + used_src = basename_src + + # Case: register_artifact() called inside restoration context + if self._is_model_being_restored() and self._is_restore_type_tarfile(): + archive_item.path_type = model_utils.ArtifactPathType.TAR_PATH + else: + archive_item.path_type = model_utils.ArtifactPathType.LOCAL_PATH + + else: + used_src = src + archive_item.path_type = model_utils.ArtifactPathType.LOCAL_PATH + + if not os.path.exists(used_src): + # File not found in local path or by basename + # Try to locate it inside the .nemo archive (if model was restored) + # Case: register_artifact() called outside restoration context + if self._is_restore_type_tarfile(): + # Get path where the command is executed - the artifacts will be "retrieved" there + # (original .nemo behavior) + cwd = os.getcwd() + try: + # Step into the nemo archive to try and find the file + with tempfile.TemporaryDirectory() as tmpdir: + self.__unpack_nemo_file(path2file=_MODEL_RESTORE_PATH, out_folder=tmpdir) + os.chdir(tmpdir) + if os.path.exists(basename_src): + logging.warning(f"Using {os.path.abspath(basename_src)} instead of {src}.") + used_src = basename_src + + archive_item.path = used_src + archive_item.path_type = model_utils.ArtifactPathType.TAR_PATH + else: + # No further action can be taken, file not found anywhere + raise FileNotFoundError( + f"Could not find {used_src} inside " + f"tarfile {_MODEL_RESTORE_PATH} or under local" + ) + finally: + # change back working directory + os.chdir(cwd) + else: + # No further action can be taken, file not found anywhere + raise FileNotFoundError(f"Could not find {used_src}") + else: + # Found filepath + archive_item.path = used_src + + # But disregarding whether you use "local" or "remote" artifact - always store the original path. + # This fixes issues raising when finetuning NLP models that create and register tokenizer vocabs. + if config_path in self.artifacts: + logging.warning( + f"Artifact {config_path} with value '{self.artifacts[config_path]}' " + f"already exists and will be overwritten with value '{src}'!" + ) + + self.artifacts[config_path] = archive_item + return used_src + else: + return src + + def _default_save_to(self, save_path: str): + """ + Saves model instance (weights and configuration) into .nemo file. + You can use "restore_from" method to fully restore instance from .nemo file. + + .nemo file is an archive (tar.gz) with the following: + model_config.yaml - model configuration in .yaml format. You can deserialize this into cfg argument for model's constructor + model_wights.chpt - model checkpoint + + Args: + save_path: Path to .nemo file where model instance should be saved + """ + with tempfile.TemporaryDirectory() as tmpdir: + config_yaml = path.join(tmpdir, _MODEL_CONFIG_YAML) + model_weights = path.join(tmpdir, _MODEL_WEIGHTS) + + if hasattr(self, 'artifacts') and self.artifacts is not None: + for (conf_path, src) in self.artifacts.items(): # type: (str, model_utils.ArtifactItem) + try: + if src.path_type == model_utils.ArtifactPathType.LOCAL_PATH and os.path.exists(src.path): + shutil.copy2(src.path, tmpdir) + elif src.path_type == model_utils.ArtifactPathType.TAR_PATH: + # Need to step into nemo archive to extract file + # Get path where the command is executed - the artifacts will be "retrieved" there + # (original .nemo behavior) + cwd = os.getcwd() + try: + # Step into the nemo archive to try and find the file + with tempfile.TemporaryDirectory() as archive_dir: + self.__unpack_nemo_file(path2file=_MODEL_RESTORE_PATH, out_folder=archive_dir) + os.chdir(archive_dir) + shutil.copy2(src.path, tmpdir) + finally: + # change back working directory + os.chdir(cwd) + else: + raise ValueError(f"Invalid ArchivePathType found: {src.path_type}") + except Exception: + logging.error(f"Could not copy artifact {src} used in {conf_path}") + + self.to_config_file(path2yaml_file=config_yaml) + torch.save(self.state_dict(), model_weights) + self.__make_nemo_file_from_folder(filename=save_path, source_dir=tmpdir) + + def _eff_save_to(self, save_path: str): + """ + Saves model instance (weights, configuration and artifacts) into an EFF archive using + the default `save_to` recipe from NeMoCookbook. + + .. note:: + For NVIDIA NeMo the EFF archives will also use .nemo postfix. + + Method creates an EFF-based file that is an archive (tar.gz) with the following: + manifest.yaml - yaml file describing the content of the archive. + model_config.yaml - model configuration in .yaml format. + You can deserialize this into cfg argument for model's constructor + model_wights.chpt - model checkpoint + + Args: + save_path: Path to archive file where model instance should be saved. + """ + NeMoCookbook().save_to(obj=self, save_path=save_path) + + @rank_zero_only + def save_to(self, save_path: str): + """ + Saves model instance (weights and configuration) into EFF archive or . + You can use "restore_from" method to fully restore instance from .nemo file. + + .nemo file is an archive (tar.gz) with the following: + model_config.yaml - model configuration in .yaml format. You can deserialize this into cfg argument for model's constructor + model_wights.chpt - model checkpoint + + Args: + save_path: Path to .nemo file where model instance should be saved + """ + + # Add nemo rank check as well + if not is_global_rank_zero(): + return + + if _EFF_PRESENT_ and self.use_eff_save(): + # Save EFF archive. + self._eff_save_to(save_path) + else: + # Save .nemo tar archive. + self._default_save_to(save_path) + + @classmethod + def _default_restore_from( + cls, + restore_path: str, + override_config_path: Optional[str] = None, + map_location: Optional[torch.device] = None, + strict: bool = False, + ): + """ + Restores model instance (weights and configuration) into .nemo file + Args: + restore_path: path to .nemo file from which model should be instantiated + override_config_path: path to a yaml config that will override the internal + config file + map_location: Optional torch.device() to map the instantiated model to a device. + By default (None), it will select a GPU if available, falling back to CPU otherwise. + strict: Passed to load_state_dict. + + Example: + ``` + model = nemo.collections.asr.models.EncDecCTCModel.restore_from('asr.nemo') + assert isinstance(model, nemo.collections.asr.models.EncDecCTCModel) + ``` + + Returns: + An instance of type cls + """ + # Get path where the command is executed - the artifacts will be "retrieved" there + # (original .nemo behavior) + cwd = os.getcwd() + + if map_location is None: + if torch.cuda.is_available(): + map_location = torch.device('cuda') + else: + map_location = torch.device('cpu') + + with tempfile.TemporaryDirectory() as tmpdir: + try: + cls._set_model_restore_state(is_being_restored=True) + cls.__unpack_nemo_file(path2file=restore_path, out_folder=tmpdir) + os.chdir(tmpdir) + if override_config_path is None: + config_yaml = path.join(tmpdir, _MODEL_CONFIG_YAML) + else: + config_yaml = override_config_path + conf = OmegaConf.load(config_yaml) + if override_config_path is not None: + # Resolve the override config + conf = OmegaConf.to_container(conf, resolve=True) + conf = OmegaConf.create(conf) + # If override is top level config, extract just `model` from it + if 'model' in conf: + conf = conf.model + model_weights = path.join(tmpdir, _MODEL_WEIGHTS) + OmegaConf.set_struct(conf, True) + instance = cls.from_config_dict(config=conf) + instance = instance.to(map_location) + instance.load_state_dict(torch.load(model_weights, map_location=map_location), strict=strict) + + logging.info(f'Model {cls.__name__} was successfully restored from {restore_path}.') + finally: + cls._set_model_restore_state(is_being_restored=False) + os.chdir(cwd) + + return instance + + @classmethod + def _eff_restore_from( + cls, + restore_path: str, + override_config_path: Optional[str] = None, + map_location: Optional[torch.device] = None, + strict: bool = False, + ): + """ + Restores model instance (weights, configuration and artifacts) from EFF Archive using + the default `restore_from` recipe from NeMoCookbook. + + Args: + restore_path: path to file from which model should be instantiated + override_config_path: path to a yaml config that will override the internal + config file + map_location: Optional torch.device() to map the instantiated model to a device. + By default (None), it will select a GPU if available, falling back to CPU otherwise. + strict: Passed to load_state_dict. + + Returns: + An instance of type cls + """ + return NeMoCookbook().restore_from( + restore_path=restore_path, + obj_cls=cls, + override_config_path=override_config_path, + map_location=map_location, + strict=strict, + ) + + @classmethod + def restore_from( + cls, + restore_path: str, + override_config_path: Optional[str] = None, + map_location: Optional[torch.device] = None, + strict: bool = False, + ): + """ + Restores model instance (weights and configuration) from file. + + The methods tries to load it as EFF archive. + If EFF library is not present in the system, or the indicated file is not EFF archive, + the function defaults to the original .nemo restore method. + + Args: + restore_path: path to .nemo file from which model should be instantiated + override_config_path: path to a yaml config that will override the internal + config file + map_location: Optional torch.device() to map the instantiated model to a device. + By default (None), it will select a GPU if available, falling back to CPU otherwise. + strict: Passed to load_state_dict. + + Example: + ``` + model = nemo.collections.asr.models.EncDecCTCModel.restore_from('asr.nemo') + assert isinstance(model, nemo.collections.asr.models.EncDecCTCModel) + ``` + + Returns: + An instance of type cls + """ + if not path.exists(restore_path): + raise FileNotFoundError(f"Can't find {restore_path}") + + global _MODEL_RESTORE_PATH + _MODEL_RESTORE_PATH = os.path.abspath(os.path.expanduser(restore_path)) + + if _EFF_PRESENT_: + # Try to load the EFF archive. + try: + return cls._eff_restore_from(restore_path, override_config_path, map_location, strict) + except (FileNotFoundError, TypeError): + # Default to the old .nemo tar archive restore method. + return cls._default_restore_from(restore_path, override_config_path, map_location, strict) + else: + # Load .nemo tar archive using the old restore method. + return cls._default_restore_from(restore_path, override_config_path, map_location, strict) + + @classmethod + def extract_state_dict_from(cls, restore_path: str, save_dir: str, split_by_module: bool = False): + """ + Extract the state dict(s) from a provided .nemo tarfile and save it to a directory. + Args: + restore_path: path to .nemo file from which state dict(s) should be extracted + save_dir: directory in which the saved state dict(s) should be stored + split_by_module: bool flag, which determins whether the output checkpoint should + be for the entire Model, or the individual module's that comprise the Model + + Example: + To convert the .nemo tarfile into a single Model level PyTorch checkpoint + ``` + state_dict = nemo.collections.asr.models.EncDecCTCModel.extract_state_dict_from('asr.nemo', './asr_ckpts) + ``` + + To restore a model from a Model level checkpoint + ``` + model = nemo.collections.asr.models.EncDecCTCModel(cfg) # or any other method of restoration + model.load_state_dict(torch.load("./asr_ckpts/model_weights.ckpt")) + ``` + + To convert the .nemo tarfile into multiple Module level PyTorch checkpoints + ``` + state_dict = nemo.collections.asr.models.EncDecCTCModel.extract_state_dict_from('asr.nemo', './asr_ckpts, + split_by_module=True) + ``` + + To restore a module from a Module level checkpoint + ``` + model = model = nemo.collections.asr.models.EncDecCTCModel(cfg) # or any other method of restoration + + # load the individual components + model.preprocessor.load_state_dict(torch.load("./asr_ckpts/preprocessor.ckpt")) + model.encoder.load_state_dict(torch.load("./asr_ckpts/encoder.ckpt")) + model.decoder.load_state_dict(torch.load("./asr_ckpts/decoder.ckpt")) + ``` + + Returns: + The state dict that was loaded from the original .nemo checkpoint + """ + if not path.exists(restore_path): + raise FileExistsError(f"Can't find {restore_path}") + + cwd = os.getcwd() + + save_dir = os.path.abspath(save_dir) + if not os.path.exists(save_dir): + os.makedirs(save_dir, exist_ok=True) + + with tempfile.TemporaryDirectory() as tmpdir: + try: + cls.__unpack_nemo_file(path2file=restore_path, out_folder=tmpdir) + os.chdir(tmpdir) + model_weights = path.join(tmpdir, _MODEL_WEIGHTS) + state_dict = torch.load(model_weights) + + if not split_by_module: + filepath = os.path.join(save_dir, _MODEL_WEIGHTS) + torch.save(state_dict, filepath) + + else: + key_set = set([key.split(".")[0] for key in state_dict.keys()]) + for primary_key in key_set: + inner_keys = [key for key in state_dict.keys() if key.split(".")[0] == primary_key] + state_dict_subset = { + ".".join(inner_key.split(".")[1:]): state_dict[inner_key] for inner_key in inner_keys + } + filepath = os.path.join(save_dir, f"{primary_key}.ckpt") + torch.save(state_dict_subset, filepath) + + logging.info(f'Checkpoints from {restore_path} were successfully extracted into {save_dir}.') + finally: + os.chdir(cwd) + + return state_dict + + @classmethod + def load_from_checkpoint( + cls, + checkpoint_path: str, + *args, + map_location: Optional[Union[Dict[str, str], str, torch.device, int, Callable]] = None, + hparams_file: Optional[str] = None, + strict: bool = True, + **kwargs, + ): + """ + Loads ModelPT from checkpoint, with some maintenance of restoration. + For documentation, please refer to LightningModule.load_from_checkpoin() documentation. + """ + checkpoint = None + try: + cls._set_model_restore_state(is_being_restored=True) + + checkpoint = super().load_from_checkpoint( + checkpoint_path=checkpoint_path, + *args, + map_location=map_location, + hparams_file=hparams_file, + strict=strict, + **kwargs, + ) + + finally: + cls._set_model_restore_state(is_being_restored=False) + return checkpoint + + @classmethod + def load_state_from_checkpoint( + cls, + model, + checkpoint_path: str, + map_location: Optional[Union[Dict[str, str], str, torch.device, int, Callable]] = None, + strict: bool = True, + ): + try: + cls._set_model_restore_state(is_being_restored=True) + + from pytorch_lightning.utilities.cloud_io import load as pl_load + if map_location is not None: + checkpoint = pl_load(checkpoint_path, map_location=map_location) + else: + checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage) + + # for past checkpoint need to add the new key + if cls.CHECKPOINT_HYPER_PARAMS_KEY not in checkpoint: + checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY] = {} + + # give model a chance to load something + model.on_load_checkpoint(checkpoint) + + # load the state_dict on the model automatically + model.load_state_dict(checkpoint['state_dict'], strict=strict) + finally: + cls._set_model_restore_state(is_being_restored=False) + + @abstractmethod + def setup_training_data(self, train_data_config: Union[DictConfig, Dict]): + """ + Setups data loader to be used in training + + Args: + train_data_layer_config: training data layer parameters. + Returns: + + """ + pass + + @abstractmethod + def setup_validation_data(self, val_data_config: Union[DictConfig, Dict]): + """ + Setups data loader to be used in validation + Args: + + val_data_layer_config: validation data layer parameters. + Returns: + + """ + pass + + def setup_test_data(self, test_data_config: Union[DictConfig, Dict]): + """ + (Optionally) Setups data loader to be used in test + + Args: + test_data_layer_config: test data layer parameters. + Returns: + + """ + raise NotImplementedError() + + def setup_multiple_validation_data(self, val_data_config: Union[DictConfig, Dict]): + """ + (Optionally) Setups data loader to be used in validation, with support for multiple data loaders. + + Args: + val_data_layer_config: validation data layer parameters. + """ + # Set some placeholder overriden by helper method + self._val_dl_idx = 0 + self._validation_names = None + self._validation_dl = None # type: torch.utils.data.DataLoader + + # preserve config + self._update_dataset_config(dataset_name='validation', config=val_data_config) + + try: + self._multi_dataset_mode = True + model_utils.resolve_validation_dataloaders(model=self) + finally: + self._multi_dataset_mode = False + + if self._validation_names is None: + if self._validation_dl is not None and type(self._validation_dl) in [list, tuple]: + self._validation_names = ['val_{}_'.format(idx) for idx in range(len(self._validation_dl))] + + def setup_multiple_test_data(self, test_data_config: Union[DictConfig, Dict]): + """ + (Optionally) Setups data loader to be used in test, with support for multiple data loaders. + + Args: + test_data_layer_config: test data layer parameters. + """ + # Set some placeholder overriden by helper method + self._test_dl_idx = 0 + self._test_names = None + self._test_dl = None # type: torch.utils.data.DataLoader + + # preserve config + self._update_dataset_config(dataset_name='test', config=test_data_config) + + try: + self._multi_dataset_mode = True + model_utils.resolve_test_dataloaders(model=self) + finally: + self._multi_dataset_mode = False + + if self._test_names is None: + if self._test_dl is not None and type(self._test_dl) in [list, tuple]: + self._test_names = ['test_{}_'.format(idx) for idx in range(len(self._test_dl))] + + def setup_optimization(self, optim_config: Optional[Union[DictConfig, Dict]] = None): + """ + Prepares an optimizer from a string name and its optional config parameters. + + Args: + optim_config: A dictionary containing the following keys: + + * "lr": mandatory key for learning rate. Will raise ValueError if not provided. + * "optimizer": string name pointing to one of the available optimizers in the registry. \ + If not provided, defaults to "adam". + * "opt_args": Optional list of strings, in the format "arg_name=arg_value". \ + The list of "arg_value" will be parsed and a dictionary of optimizer kwargs \ + will be built and supplied to instantiate the optimizer. + """ + # If config was not explicitly passed to us + if optim_config is None: + # See if internal config has `optim` namespace + if self._cfg is not None and hasattr(self._cfg, 'optim'): + optim_config = self._cfg.optim + + # If config is still None, or internal config has no Optim, return without instantiation + if optim_config is None: + logging.info('No optimizer config provided, therefore no optimizer was created') + return + + else: + # Preserve the configuration + if not isinstance(optim_config, DictConfig): + optim_config = OmegaConf.create(optim_config) + + # See if internal config has `optim` namespace before preservation + if self._cfg is not None and hasattr(self._cfg, 'optim'): + if self._cfg.optim is None: + self._cfg.optim = copy.deepcopy(optim_config) + else: + with open_dict(self._cfg.optim): + self._cfg.optim = copy.deepcopy(optim_config) + + # Setup optimizer and scheduler + if optim_config is not None and isinstance(optim_config, DictConfig): + optim_config = OmegaConf.to_container(optim_config, resolve=True) + + if 'sched' in optim_config and self._trainer is not None: + if not isinstance(self._trainer.accumulate_grad_batches, int): + raise ValueError("We do not currently support gradient acculumation that is not an integer.") + if self._trainer.max_steps is None: + # Store information needed to calculate max_steps + optim_config['sched']['t_max_epochs'] = self._trainer.max_epochs + optim_config['sched']['t_accumulate_grad_batches'] = self._trainer.accumulate_grad_batches + optim_config['sched']['t_limit_train_batches'] = self._trainer.limit_train_batches + if self._trainer.distributed_backend is None: + optim_config['sched']['t_num_workers'] = self._trainer.num_gpus or 1 + elif self._trainer.distributed_backend == "ddp_cpu": + optim_config['sched']['t_num_workers'] = self._trainer.num_processes * self._trainer.num_nodes + elif self._trainer.distributed_backend == "ddp": + optim_config['sched']['t_num_workers'] = self._trainer.num_gpus * self._trainer.num_nodes + else: + logging.warning( + f"The lightning trainer received accelerator: {self._trainer.distributed_backend}. We " + "recommend to use 'ddp' instead." + ) + optim_config['sched']['t_num_workers'] = self._trainer.num_gpus * self._trainer.num_nodes + else: + optim_config['sched']['max_steps'] = self._trainer.max_steps + + # Force into DictConfig from nested structure + optim_config = OmegaConf.create(optim_config) + # Get back nested dict so we its mutable + optim_config = OmegaConf.to_container(optim_config, resolve=True) + + # Extract scheduler config if inside optimizer config + if 'sched' in optim_config: + scheduler_config = optim_config.pop('sched') + else: + scheduler_config = None + + # Check if caller provided optimizer name, default to Adam otherwise + optimizer_cls = optim_config.get('_target_', None) + + if optimizer_cls is None: + # Try to get optimizer name for dynamic resolution, defaulting to Adam + optimizer_name = optim_config.get('name', 'adam') + else: + if inspect.isclass(optimizer_cls): + optimizer_name = optimizer_cls.__name__.lower() + else: + # resolve the class name (lowercase) from the class path if not provided + optimizer_name = optimizer_cls.split(".")[-1].lower() + + # We are guarenteed to have lr since it is required by the argparser + # But maybe user forgot to pass it to this function + lr = optim_config.get('lr', None) + + # Check if caller has optimizer kwargs, default to empty dictionary + if 'args' in optim_config: + optimizer_args = optim_config.pop('args') + optimizer_args = optim.parse_optimizer_args(optimizer_name, optimizer_args) + else: + optimizer_args = copy.deepcopy(optim_config) + + # Remove extra parameters from optimizer_args nest + # Assume all other parameters are to be passed into optimizer constructor + optimizer_args.pop('name', None) + optimizer_args.pop('cls', None) + optimizer_args.pop('lr', None) + + # Adaptive schedulers don't need `lr` + if lr is not None: + optimizer_args['lr'] = lr + + # Actually instantiate the optimizer + if optimizer_cls is not None: + if inspect.isclass(optimizer_cls): + optimizer = optimizer_cls(self.optim_param_groups(), **optimizer_args) + logging.info("Optimizer config = %s", str(optimizer)) + + self._optimizer = optimizer + + else: + # Attempt class path resolution + try: + optimizer_cls = OmegaConf.create({'_target_': optimizer_cls}) + if lr is not None: + optimizer_config = {'lr': lr} + else: + optimizer_config = {} + optimizer_config.update(optimizer_args) + + optimizer_instance = hydra.utils.instantiate( + optimizer_cls, self.optim_param_groups(), **optimizer_config + ) # type: DictConfig + + logging.info("Optimizer config = %s", str(optimizer_instance)) + + self._optimizer = optimizer_instance + + except Exception as e: + logging.error( + "Could not instantiate class path - {} with kwargs {}".format( + optimizer_cls, str(optimizer_config) + ) + ) + raise e + + else: + optimizer = optim.get_optimizer(optimizer_name) + optimizer = optimizer(self.optim_param_groups(), **optimizer_args) + + logging.info("Optimizer config = %s", str(optimizer)) + + self._optimizer = optimizer + + # Try to instantiate scheduler for optimizer + self._scheduler = prepare_lr_scheduler( + optimizer=self._optimizer, scheduler_config=scheduler_config, train_dataloader=self._train_dl + ) + + # Return the optimizer with/without scheduler + # This return allows multiple optimizers or schedulers to be created + return self._optimizer, self._scheduler + + def optim_param_groups(self): + return self.parameters() + + def configure_optimizers(self): + self.setup_optimization() + + if self._scheduler is None: + return self._optimizer + else: + return [self._optimizer], [self._scheduler] + + def train_dataloader(self): + if self._train_dl is not None: + return self._train_dl + + def val_dataloader(self): + if self._validation_dl is not None: + return self._validation_dl + + def test_dataloader(self): + if self._test_dl is not None: + return self._test_dl + + def validation_epoch_end( + self, outputs: Union[List[Dict[str, torch.Tensor]], List[List[Dict[str, torch.Tensor]]]] + ) -> Optional[Dict[str, Dict[str, torch.Tensor]]]: + """ + Default DataLoader for Validation set which automatically supports multiple data loaders + via `multi_validation_epoch_end`. + + If multi dataset support is not required, override this method entirely in base class. + In such a case, there is no need to implement `multi_validation_epoch_end` either. + + .. note:: + If more than one data loader exists, and they all provide `val_loss`, + only the `val_loss` of the first data loader will be used by default. + This default can be changed by passing the special key `val_dl_idx: int` + inside the `validation_ds` config. + + Args: + outputs: Single or nested list of tensor outputs from one or more data loaders. + + Returns: + A dictionary containing the union of all items from individual data_loaders, + along with merged logs from all data loaders. + """ + # Case where we dont provide data loaders + if outputs is not None and len(outputs) == 0: + return {} + + # Case where we provide exactly 1 data loader + if type(outputs[0]) == dict: + output_dict = self.multi_validation_epoch_end(outputs, dataloader_idx=0) + + if output_dict is not None and 'log' in output_dict: + self.log_dict(output_dict.pop('log'), on_epoch=True) + + return output_dict + + else: # Case where we provide more than 1 data loader + output_dict = {'log': {}} + + # The output is a list of list of dicts, outer list corresponds to dataloader idx + for dataloader_idx, val_outputs in enumerate(outputs): + # Get prefix and dispatch call to multi epoch end + dataloader_prefix = self.get_validation_dataloader_prefix(dataloader_idx) + dataloader_logs = self.multi_validation_epoch_end(val_outputs, dataloader_idx=dataloader_idx) + + # If result was not provided, generate empty dict + dataloader_logs = dataloader_logs or {} + + # Perform `val_loss` resolution first (if provided outside logs) + if 'val_loss' in dataloader_logs: + if 'val_loss' not in output_dict and dataloader_idx == self._val_dl_idx: + output_dict['val_loss'] = dataloader_logs['val_loss'] + + # For every item in the result dictionary + for k, v in dataloader_logs.items(): + # If the key is `log` + if k == 'log': + # Parse every element of the log, and attach the prefix name of the data loader + log_dict = {} + + for k_log, v_log in v.items(): + # If we are logging the metric, but dont provide it at result level, + # store it twice - once in log and once in result level. + # Also mark log with prefix name to avoid log level clash with other data loaders + if k_log not in output_dict['log'] and dataloader_idx == self._val_dl_idx: + new_k_log = k_log + + # Also insert duplicate key with prefix for ease of comparison / avoid name clash + log_dict[dataloader_prefix + k_log] = v_log + + else: + # Simply prepend prefix to key and save + new_k_log = dataloader_prefix + k_log + + # Store log value + log_dict[new_k_log] = v_log + + # Update log storage of individual data loader + output_logs = output_dict['log'] + output_logs.update(log_dict) + + # Update global log storage + output_dict['log'] = output_logs + + else: + # If any values are stored outside 'log', simply prefix name and store + new_k = dataloader_prefix + k + output_dict[new_k] = v + + if 'log' in output_dict: + self.log_dict(output_dict.pop('log'), on_epoch=True) + + # return everything else + return output_dict + + def test_epoch_end( + self, outputs: Union[List[Dict[str, torch.Tensor]], List[List[Dict[str, torch.Tensor]]]] + ) -> Optional[Dict[str, Dict[str, torch.Tensor]]]: + """ + Default DataLoader for Test set which automatically supports multiple data loaders + via `multi_test_epoch_end`. + + If multi dataset support is not required, override this method entirely in base class. + In such a case, there is no need to implement `multi_test_epoch_end` either. + + .. note:: + If more than one data loader exists, and they all provide `test_loss`, + only the `test_loss` of the first data loader will be used by default. + This default can be changed by passing the special key `test_dl_idx: int` + inside the `test_ds` config. + + Args: + outputs: Single or nested list of tensor outputs from one or more data loaders. + + Returns: + A dictionary containing the union of all items from individual data_loaders, + along with merged logs from all data loaders. + """ + # Case where we dont provide data loaders + if outputs is not None and len(outputs) == 0: + return {} + + # Case where we provide exactly 1 data loader + if type(outputs[0]) == dict: + output_dict = self.multi_test_epoch_end(outputs, dataloader_idx=0) + + if output_dict is not None and 'log' in output_dict: + self.log_dict(output_dict.pop('log'), on_epoch=True) + + return output_dict + + else: # Case where we provide more than 1 data loader + output_dict = {'log': {}} + + # The output is a list of list of dicts, outer list corresponds to dataloader idx + for dataloader_idx, test_outputs in enumerate(outputs): + # Get prefix and dispatch call to multi epoch end + dataloader_prefix = self.get_test_dataloader_prefix(dataloader_idx) + dataloader_logs = self.multi_test_epoch_end(test_outputs, dataloader_idx=dataloader_idx) + + # If result was not provided, generate empty dict + dataloader_logs = dataloader_logs or {} + + # Perform `test_loss` resolution first (if provided outside logs) + if 'test_loss' in dataloader_logs: + if 'test_loss' not in output_dict and dataloader_idx == self._test_dl_idx: + output_dict['test_loss'] = dataloader_logs['test_loss'] + + # For every item in the result dictionary + for k, v in dataloader_logs.items(): + # If the key is `log` + if k == 'log': + # Parse every element of the log, and attach the prefix name of the data loader + log_dict = {} + for k_log, v_log in v.items(): + # If we are logging the loss, but dont provide it at result level, + # store it twice - once in log and once in result level. + # Also mark log with prefix name to avoid log level clash with other data loaders + if k_log not in output_dict['log'] and dataloader_idx == self._test_dl_idx: + new_k_log = k_log + + # Also insert duplicate key with prefix for ease of comparison / avoid name clash + log_dict[dataloader_prefix + k_log] = v_log + + else: + # Simply prepend prefix to key and save + new_k_log = dataloader_prefix + k_log + + log_dict[new_k_log] = v_log + + # Update log storage of individual data loader + output_logs = output_dict.get('log', {}) + output_logs.update(log_dict) + + # Update global log storage + output_dict['log'] = output_logs + + else: + # If any values are stored outside 'log', simply prefix name and store + new_k = dataloader_prefix + k + output_dict[new_k] = v + + if 'log' in output_dict: + self.log_dict(output_dict.pop('log'), on_epoch=True) + + # return everything else + return output_dict + + def multi_validation_epoch_end( + self, outputs: List[Dict[str, torch.Tensor]], dataloader_idx: int = 0 + ) -> Optional[Dict[str, Dict[str, torch.Tensor]]]: + logging.warning( + "Multi data loader support has been enabled, but " + "`multi_validation_epoch_end(outputs, dataloader_idx) has not been implemented.\n" + "If you require multi data loader support for validation sets, please override this method.\n" + "If you do not require multi data loader support, please instead override " + "`validation_epoch_end(outputs)." + ) + + def multi_test_epoch_end( + self, outputs: List[Dict[str, torch.Tensor]], dataloader_idx: int = 0 + ) -> Optional[Dict[str, Dict[str, torch.Tensor]]]: + logging.warning( + "Multi data loader support has been enabled, but " + "`multi_test_epoch_end(outputs, dataloader_idx) has not been implemented.\n" + "If you require multi data loader support for validation sets, please override this method.\n" + "If you do not require multi data loader support, please instead override " + "`test_epoch_end(outputs)." + ) + + def get_validation_dataloader_prefix(self, dataloader_idx: int = 0) -> str: + """ + Get the name of one or more data loaders, which will be prepended to all logs. + + Args: + dataloader_idx: Index of the data loader. + + Returns: + str name of the data loader at index provided. + """ + return self._validation_names[dataloader_idx] + + def get_test_dataloader_prefix(self, dataloader_idx: int = 0) -> str: + """ + Get the name of one or more data loaders, which will be prepended to all logs. + + Args: + dataloader_idx: Index of the data loader. + + Returns: + str name of the data loader at index provided. + """ + return self._test_names[dataloader_idx] + + def teardown(self, stage: str): + """ + Called at the end of fit and test. + + Args: + stage: either 'fit' or 'test' + """ + if stage == 'fit': + # Update env variable to bypass multi gpu issue after training + # This fix affects usage of trainer.test() after trainer.train() + # If trainer.train() was done on multiple GPUs, then trainer.test() + # will try to do ddp, even if its a new Trainer object with just 1 GPU. + # Temporary patch to fix that + if 'PL_TRAINER_GPUS' in os.environ: + os.environ.pop('PL_TRAINER_GPUS') + + super().teardown(stage) + + def prepare_test(self, trainer: 'Trainer') -> bool: + """ + Helper method to check whether the model can safely be tested + on a dataset after training (or loading a checkpoint). + + # Usage: + trainer = Trainer() + if model.prepare_test(trainer): + trainer.test(model) + + Returns: + bool which declares the model safe to test. Provides warnings if it has to + return False to guide the user. + """ + if not hasattr(self._cfg, 'test_ds'): + logging.info("No `test_ds` config found within the manifest.") + return False + + # Replace ddp multi-gpu until PTL has a fix + DDP_WARN = """\n\nDuring testing, it is currently advisable to construct a new Trainer " + "with single GPU and no DDP to obtain accurate results. + "Following pattern should be used: " + "gpu = 1 if cfg.trainer.gpus != 0 else 0" + "trainer = Trainer(gpus=gpu)" + "if model.prepare_test(trainer):" + " trainer.test(model)\n\n""" + + if trainer is not None: + if trainer.num_gpus > 1: + logging.warning(DDP_WARN) + return False + + # Assign trainer to the model + self.set_trainer(trainer) + return True + + def set_trainer(self, trainer: Trainer): + """ + Set an instance of Trainer object. + + Args: + trainer: PyTorch Lightning Trainer object. + """ + self._trainer = trainer + self.set_world_size(self._trainer) + + def set_world_size(self, trainer: Trainer): + """ + Determines the world size from the PyTorch Lightning Trainer. + And then updates AppState. + + Args: + trainer (Trainer): PyTorch Lightning Trainer object + """ + # Update AppState with world information from trainer + if isinstance(trainer, Trainer): + app_state = AppState() + if self._trainer.num_gpus and self._trainer.num_nodes: + app_state.world_size = self._trainer.num_gpus * self._trainer.num_nodes + else: + logging.warning(f'World size can only be set by PyTorch Lightning Trainer.') + + def _update_dataset_config(self, dataset_name: str, config: Optional[Union[DictConfig, Dict]]): + """ + Update the config (if not None) of the dataset by given name. + Preserves said config after updating. + + Args: + dataset_name: str name of the dataset whose config is being updated. + Can be one of `train`, `validation` and `test`. + config: Optional DictConfig or dict. If None is passed, this method simply returns. + If dict is passed, it is cast into a DictConfig. + The internal config is updated with the passed config. + """ + if hasattr(self, '_multi_dataset_mode') and self._multi_dataset_mode is True: + return + + if config is not None: + if not isinstance(config, DictConfig): + config = OmegaConf.create(config) + + if dataset_name in ['train', 'validation', 'test']: + OmegaConf.set_struct(self.cfg, False) + + key_name = dataset_name + "_ds" + self.cfg[key_name] = config + + OmegaConf.set_struct(self.cfg, True) + + # Update hyper parameters by calling property setter + self.cfg = self._cfg + else: + raise ValueError("`dataset_name` when updating config must be one of [train, validation, test]") + + @property + def num_weights(self): + return sum(p.numel() for p in self.parameters() if p.requires_grad) + + @property + def cfg(self): + return self._cfg + + @cfg.setter + def cfg(self, cfg): + self._cfg = cfg + self._set_hparams(cfg) + + @staticmethod + def __make_nemo_file_from_folder(filename, source_dir): + with tarfile.open(filename, "w:gz") as tar: + # tar.add(source_dir, arcname=path.basename(source_dir)) + tar.add(source_dir, arcname="./") + + @staticmethod + def __unpack_nemo_file(path2file: str, out_folder: str) -> str: + if not path.exists(path2file): + raise FileNotFoundError(f"{path2file} does not exist") + tar = tarfile.open(path2file, "r:gz") + tar.extractall(path=out_folder) + tar.close() + return out_folder + + @staticmethod + def _is_model_being_restored() -> bool: + global _MODEL_IS_RESTORED + return _MODEL_IS_RESTORED + + @staticmethod + def _set_model_restore_state(is_being_restored: bool): + global _MODEL_IS_RESTORED + _MODEL_IS_RESTORED = is_being_restored + + @staticmethod + def _is_restore_type_tarfile() -> bool: + """ + Utility method that checks if the restore path of the underlying Model + is a tarfile (can be any valid archive)._MODEL_EFF_SAVE + """ + global _MODEL_RESTORE_PATH + + if _MODEL_RESTORE_PATH is None: + return False + else: + if tarfile.is_tarfile(_MODEL_RESTORE_PATH): + return True + else: + return False + + @staticmethod + def set_eff_save(use_eff_save: bool): + global _MODEL_EFF_SAVE + _MODEL_EFF_SAVE = use_eff_save + + @staticmethod + def use_eff_save() -> bool: + global _MODEL_EFF_SAVE + return _MODEL_EFF_SAVE diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/classes/module.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/classes/module.py new file mode 100644 index 0000000000000000000000000000000000000000..e3e1ace2f483c1a7bec504c63d9d5854bb0e5231 --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/classes/module.py @@ -0,0 +1,73 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from contextlib import contextmanager + +from torch.nn import Module + +from nemo.core.classes.common import FileIO, Serialization, Typing + +__all__ = ['NeuralModule'] + + +class NeuralModule(Module, Typing, Serialization, FileIO): + """ + Abstract class offering interface shared between all PyTorch Neural Modules. + """ + + @property + def num_weights(self): + return sum(p.numel() for p in self.parameters() if p.requires_grad) + + def input_example(self): + """ + Override this method if random inputs won't work + Returns: + A tuple sample of valid input data. + """ + + return + + def freeze(self): + r""" + Freeze all params for inference. + """ + requires_grad_states = [] + for param in self.parameters(): + requires_grad_states.append(param.requires_grad) + param.requires_grad = False + + self.eval() + return requires_grad_states + + def unfreeze(self, requires_grad_states) -> None: + """ + Unfreeze all parameters for training. + """ + for i, param in enumerate(self.parameters()): + param.requires_grad = requires_grad_states[i] + + self.train() + + @contextmanager + def as_frozen(self): + """ + Context manager which temporarily freezes a module, yields control and finally unfreezes the module. + """ + requires_grad_states = self.freeze() + + try: + yield + finally: + self.unfreeze(requires_grad_states) diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/config/__init__.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/config/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d3470ed2eb99a28343bb806c0247caa15d28b294 --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/config/__init__.py @@ -0,0 +1,47 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nemo.core.config.base_config import Config +from nemo.core.config.optimizers import ( + AdadeltaParams, + AdagradParams, + AdamaxParams, + AdamParams, + AdamWParams, + NovogradParams, + OptimizerParams, + RMSpropParams, + RpropParams, + SGDParams, + get_optimizer_config, + register_optimizer_params, +) +from nemo.core.config.pytorch import DataLoaderConfig +from nemo.core.config.pytorch_lightning import TrainerConfig, HTrainerConfig +from nemo.core.config.schedulers import ( + CosineAnnealingParams, + InverseSquareRootAnnealingParams, + NoamAnnealingParams, + PolynomialDecayAnnealingParams, + PolynomialHoldDecayAnnealingParams, + SchedulerParams, + SquareAnnealingParams, + SquareRootAnnealingParams, + WarmupAnnealingParams, + WarmupHoldSchedulerParams, + WarmupSchedulerParams, + get_scheduler_config, + register_scheduler_params, +) +from nemo.core.config.set_config import hydra_runner, set_config diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/config/base_config.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/config/base_config.py new file mode 100644 index 0000000000000000000000000000000000000000..44649ef7b4fb460f677e3d1e37b7e3a9d459e34d --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/config/base_config.py @@ -0,0 +1,30 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Optional + +__all__ = ['Config'] + + +@dataclass +class Config: + """ + Abstract NeMo Configuration class. + + Args: + name: name of the module/dataset/loss/model object (used in serialization, DEFAULT: None) + """ + + name: Optional[str] = None diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/config/modelPT.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/config/modelPT.py new file mode 100644 index 0000000000000000000000000000000000000000..c729f29663c3d1da6e4ffd6d55ca31fad665066b --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/config/modelPT.py @@ -0,0 +1,66 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field +from typing import Any, Dict, Optional + +from omegaconf import MISSING + +from nemo.core import config +from nemo.core.classes.dataset import DatasetConfig +from nemo.utils import exp_manager + + +@dataclass +class SchedConfig: + name: str = MISSING + min_lr: float = 0.0 + last_epoch: int = -1 + + +@dataclass +class OptimConfig: + name: str = MISSING + lr: float = MISSING + sched: Optional[SchedConfig] = None + + +@dataclass +class ModelConfig: + """ + Model component inside ModelPT + """ + + # ... + train_ds: Optional[DatasetConfig] = None + validation_ds: Optional[DatasetConfig] = None + test_ds: Optional[DatasetConfig] = None + optim: Optional[OptimConfig] = None + + +@dataclass +class HydraConfig: + run: Dict[str, Any] = field(default_factory=lambda: {"dir": "."}) + job_logging: Dict[str, Any] = field(default_factory=lambda: {"root": {"handlers": None}}) + + +@dataclass +class ModelPTConfig: + name: str = MISSING + model: ModelConfig = MISSING + trainer: config.TrainerConfig = config.TrainerConfig( + accelerator="ddp", checkpoint_callback=False, logger=False, log_every_n_steps=1 + ) + exp_manager: Optional[Any] = exp_manager.ExpManagerConfig() + hydra: HydraConfig = HydraConfig() diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/config/optimizers.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/config/optimizers.py new file mode 100644 index 0000000000000000000000000000000000000000..18ace2dc0fc67eeabc334a3d935d88bee11d7d79 --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/config/optimizers.py @@ -0,0 +1,263 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from functools import partial +from typing import Any, Dict, Optional, Tuple + +from omegaconf import OmegaConf + +__all__ = [ + 'OptimizerParams', + 'AdamParams', + 'NovogradParams', + 'SGDParams', + 'AdadeltaParams', + 'AdamaxParams', + 'AdagradParams', + 'AdamWParams', + 'RMSpropParams', + 'RpropParams', +] + + +@dataclass +class OptimizerParams: + """ + Base Optimizer params with no values. User can chose it to explicitly override via + command line arguments + """ + + +@dataclass +class SGDParams(OptimizerParams): + """ + Default configuration for Adam optimizer. + It is not derived from Config as it is not a NeMo object (and in particular it doesn't need a name). + + ..note: + For the details on the function/meanings of the arguments, please refer to: + https://pytorch.org/docs/stable/optim.html?highlight=sgd#torch.optim.SGD + """ + + momentum: float = 0 + dampening: float = 0 + weight_decay: float = 0 + nesterov: bool = False + + +@dataclass +class AdamParams(OptimizerParams): + """ + Default configuration for Adam optimizer. + It is not derived from Config as it is not a NeMo object (and in particular it doesn't need a name). + + ..note: + For the details on the function/meanings of the arguments, please refer to: + https://pytorch.org/docs/stable/optim.html?highlight=adam#torch.optim.Adam + """ + + # betas: Tuple[float, float] = (0.9, 0.999) + eps: float = 1e-08 + weight_decay: float = 0 + amsgrad: bool = False + + +@dataclass +class AdamWParams(OptimizerParams): + """ + Default configuration for AdamW optimizer. + It is not derived from Config as it is not a NeMo object (and in particular it doesn't need a name). + + ..note: + For the details on the function/meanings of the arguments, please refer to: + https://pytorch.org/docs/stable/optim.html#torch.optim.AdamW + """ + + betas: Tuple[float, float] = (0.9, 0.999) + eps: float = 1e-08 + weight_decay: float = 0 + amsgrad: bool = False + + +@dataclass +class AdadeltaParams(OptimizerParams): + """ + Default configuration for Adadelta optimizer. + It is not derived from Config as it is not a NeMo object (and in particular it doesn't need a name). + + ..note: + For the details on the function/meanings of the arguments, please refer to: + https://pytorch.org/docs/stable/optim.html#torch.optim.Adadelta + """ + + rho: float = 0.9 + eps: float = 1e-6 + weight_decay: float = 0 + + +@dataclass +class AdamaxParams(OptimizerParams): + """ + Default configuration for Adamax optimizer. + It is not derived from Config as it is not a NeMo object (and in particular it doesn't need a name). + + ..note: + For the details on the function/meanings of the arguments, please refer to: + https://pytorch.org/docs/stable/optim.html#torch.optim.Adamax + """ + + betas: Tuple[float, float] = (0.9, 0.999) + eps: float = 1e-8 + weight_decay: float = 0 + + +@dataclass +class AdagradParams(OptimizerParams): + """ + Default configuration for Adagrad optimizer. + It is not derived from Config as it is not a NeMo object (and in particular it doesn't need a name). + + ..note: + For the details on the function/meanings of the arguments, please refer to: + https://pytorch.org/docs/stable/optim.html#torch.optim.Adagrad + """ + + lr_decay: float = 0 + weight_decay: float = 0 + initial_accumulator_value: float = 0 + eps: float = 1e-10 + + +@dataclass +class RMSpropParams(OptimizerParams): + """ + Default configuration for RMSprop optimizer. + It is not derived from Config as it is not a NeMo object (and in particular it doesn't need a name). + + ..note: + For the details on the function/meanings of the arguments, please refer to: + https://pytorch.org/docs/stable/optim.html#torch.optim.RMSprop + """ + + alpha: float = 0.99 + eps: float = 1e-8 + weight_decay: float = 0 + momentum: float = 0 + centered: bool = False + + +@dataclass +class RpropParams(OptimizerParams): + """ + Default configuration for RpropParams optimizer. + It is not derived from Config as it is not a NeMo object (and in particular it doesn't need a name). + + ..note: + For the details on the function/meanings of the arguments, please refer to: + https://pytorch.org/docs/stable/optim.html#torch.optim.Rprop + """ + + etas: Tuple[float, float] = (0.5, 1.2) + step_sizes: Tuple[float, float] = (1e-6, 50) + + +@dataclass +class NovogradParams(OptimizerParams): + """ + Configuration of the Novograd optimizer. + + It has been proposed in "Stochastic Gradient Methods with Layer-wise + Adaptive Moments for Training of Deep Networks" + (https://arxiv.org/abs/1905.11286) + + Args: + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + amsgrad (boolean, optional): whether to use the AMSGrad variant of this + algorithm from the paper "On the Convergence of Adam and Beyond" + """ + + betas: Tuple[float, float] = (0.95, 0.98) + eps: float = 1e-8 + weight_decay: float = 0 + grad_averaging: bool = False + amsgrad: bool = False + luc: bool = False + luc_trust: float = 1e-3 + luc_eps: float = 1e-8 + + +def register_optimizer_params(name: str, optimizer_params: OptimizerParams): + """ + Checks if the optimizer param name exists in the registry, and if it doesnt, adds it. + + This allows custom optimizer params to be added and called by name during instantiation. + + Args: + name: Name of the optimizer. Will be used as key to retrieve the optimizer. + optimizer_params: Optimizer class + """ + if name in AVAILABLE_OPTIMIZER_PARAMS: + raise ValueError(f"Cannot override pre-existing optimizers. Conflicting optimizer name = {name}") + + AVAILABLE_OPTIMIZER_PARAMS[name] = optimizer_params + + +def get_optimizer_config(name: str, **kwargs: Optional[Dict[str, Any]]) -> OptimizerParams: + """ + Convenience method to obtain a OptimizerParams class and partially instantiate it with optimizer kwargs. + + Args: + name: Name of the OptimizerParams in the registry. + kwargs: Optional kwargs of the optimizer used during instantiation. + + Returns: + a partially instantiated OptimizerParams + """ + if name is None: + return kwargs + + if name not in AVAILABLE_OPTIMIZER_PARAMS: + raise ValueError( + f"Cannot resolve optimizer parameters '{name}'. Available optimizer parameters are : " + f"{AVAILABLE_OPTIMIZER_PARAMS.keys()}" + ) + + scheduler_params = AVAILABLE_OPTIMIZER_PARAMS[name] + + if kwargs is not None and len(kwargs) != 0: + kwargs = OmegaConf.create(kwargs) + OmegaConf.merge(scheduler_params(), kwargs) + + scheduler_params = partial(scheduler_params, **kwargs) + return scheduler_params + + +AVAILABLE_OPTIMIZER_PARAMS = { + 'optim_params': OptimizerParams, + 'adam_params': AdamParams, + 'novograd_params': NovogradParams, + 'sgd_params': SGDParams, + 'adadelta_params': AdadeltaParams, + 'adamax_params': AdamaxParams, + 'adagrad_params': AdagradParams, + 'adamw_params': AdamWParams, + 'rmsprop_params': RMSpropParams, + 'rprop_params': RpropParams, +} diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/config/pytorch.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/config/pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..944df02ed802e3a91281c02a5082f522e22e044a --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/config/pytorch.py @@ -0,0 +1,45 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Any, Optional + +from omegaconf import MISSING + +__all__ = ['DataLoaderConfig'] + + +@dataclass +class DataLoaderConfig: + """ + Configuration of PyTorch DataLoader. + + It is not derived from Config as it is not a NeMo object (and in particular it doesn't need a name). + + ..note: + For the details on the function/meanings of the arguments, please refer to: + https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader + """ + + batch_size: int = MISSING + shuffle: bool = False + sampler: Optional[Any] = None + batch_sampler: Optional[Any] = None + num_workers: int = 0 + collate_fn: Optional[Any] = None + pin_memory: bool = False + drop_last: bool = False + timeout: int = 0 + worker_init_fn: Optional[Any] = None + multiprocessing_context: Optional[Any] = None diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/config/pytorch_lightning.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/config/pytorch_lightning.py new file mode 100644 index 0000000000000000000000000000000000000000..8b248959bc7446a1925ef7c4cdf6367581d60035 --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/config/pytorch_lightning.py @@ -0,0 +1,97 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Union + +from hydra.core.config_store import ConfigStore + +__all__ = ['TrainerConfig'] + + +cs = ConfigStore.instance() + + +@dataclass +class TrainerConfig: + """ + Configuration of PyTorch Lightning Trainer. + It is not derived from Config as it is not a NeMo object (and in particular it doesn't need a name). + ..warning: + Picked just few params of the PTL trainer for now. This needs to be discussed. + ..note: + For the details on the function/meanings of the arguments, please refer to: + https://pytorch-lightning.readthedocs.io/en/latest/trainer.html# + """ + + logger: Any = True + checkpoint_callback: Any = True + callbacks: Optional[Any] = None + default_root_dir: Optional[str] = None + gradient_clip_val: float = 0 + process_position: int = 0 + num_nodes: int = 1 + num_processes: int = 1 + gpus: Optional[Any] = None + auto_select_gpus: bool = False + tpu_cores: Optional[Any] = None + log_gpu_memory: Optional[str] = None + progress_bar_refresh_rate: int = 1 + overfit_batches: Any = 0.0 + track_grad_norm: Any = -1 + check_val_every_n_epoch: int = 1 + fast_dev_run: bool = False + accumulate_grad_batches: Any = 1 + max_epochs: int = 1000 + min_epochs: int = 1 + max_steps: Optional[int] = None + min_steps: Optional[int] = None + limit_train_batches: Any = 1.0 + limit_val_batches: Any = 1.0 + limit_test_batches: Any = 1.0 + val_check_interval: Any = 1.0 + flush_logs_every_n_steps: int = 100 + log_every_n_steps: int = 50 + accelerator: Optional[str] = None + sync_batchnorm: bool = False + precision: int = 32 + weights_summary: Optional[str] = "full" # ModelSummary.MODE_DEFAULT + weights_save_path: Optional[str] = None + num_sanity_val_steps: int = 2 + truncated_bptt_steps: Optional[int] = None + resume_from_checkpoint: Optional[str] = None + profiler: Optional[Any] = None + benchmark: bool = False + deterministic: bool = False + reload_dataloaders_every_epoch: bool = False + auto_lr_find: Any = False + replace_sampler_ddp: bool = True + terminate_on_nan: bool = False + auto_scale_batch_size: Any = False + prepare_data_per_node: bool = True + amp_backend: str = 'native' + amp_level: str = 'O2' # backward compatible, todo: remove in v1.0.0 + + +@dataclass +class HTrainerConfig(TrainerConfig): + find_unused_parameters: bool = False + ema_decay: float = 0.0 + ema_start_step: int = 0 + + +# Register the trainer config. +cs.store( + group="trainer", name="trainer", node=TrainerConfig, +) diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/config/schedulers.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/config/schedulers.py new file mode 100644 index 0000000000000000000000000000000000000000..629a1fb9fcef93c7bbe7e7d455048db85f030738 --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/config/schedulers.py @@ -0,0 +1,250 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from functools import partial +from typing import Any, Dict, Optional + + +@dataclass +class SchedulerParams: + """ + Base configuration for all schedulers. + It is not derived from Config as it is not a NeMo object (and in particular it doesn't need a name). + """ + + last_epoch: int = -1 + + +@dataclass +class WarmupSchedulerParams(SchedulerParams): + """ + Base configuration for all schedulers. + It is not derived from Config as it is not a NeMo object (and in particular it doesn't need a name). + """ + + warmup_steps: Optional[float] = None + warmup_ratio: Optional[float] = None + + +@dataclass +class WarmupHoldSchedulerParams(WarmupSchedulerParams): + """ + Base configuration for all schedulers. + It is not derived from Config as it is not a NeMo object (and in particular it doesn't need a name). + """ + + hold_steps: Optional[float] = None + hold_ratio: Optional[float] = None + min_lr: float = 0.0 + + +@dataclass +class SquareAnnealingParams(WarmupSchedulerParams): + """ + Square Annealing parameter config + It is not derived from Config as it is not a NeMo object (and in particular it doesn't need a name). + """ + + min_lr: float = 1e-5 + + +@dataclass +class SquareRootAnnealingParams(WarmupSchedulerParams): + """ + Square Root Annealing parameter config + It is not derived from Config as it is not a NeMo object (and in particular it doesn't need a name). + """ + + min_lr: float = 0.0 + + +@dataclass +class CosineAnnealingParams(WarmupSchedulerParams): + """ + Cosine Annealing parameter config + It is not derived from Config as it is not a NeMo object (and in particular it doesn't need a name). + """ + + min_lr: float = 0.0 + + +@dataclass +class NoamAnnealingParams(WarmupSchedulerParams): + """ + Cosine Annealing parameter config + It is not derived from Config as it is not a NeMo object (and in particular it doesn't need a name). + """ + + min_lr: float = 0.0 + + +@dataclass +class WarmupAnnealingParams(WarmupSchedulerParams): + """ + Warmup Annealing parameter config + It is not derived from Config as it is not a NeMo object (and in particular it doesn't need a name). + """ + + warmup_ratio: 0.0 + + +@dataclass +class InverseSquareRootAnnealingParams(WarmupSchedulerParams): + """ + Inverse Square Root Annealing parameter config + It is not derived from Config as it is not a NeMo object (and in particular it doesn't need a name). + """ + + +@dataclass +class PolynomialDecayAnnealingParams(WarmupSchedulerParams): + """ + Polynomial Decay Annealing parameter config + It is not derived from Config as it is not a NeMo object (and in particular it doesn't need a name). + """ + + power: float = 1.0 + cycle: bool = False + + +@dataclass +class PolynomialHoldDecayAnnealingParams(WarmupSchedulerParams): + """ + Polynomial Hold Decay Annealing parameter config + It is not derived from Config as it is not a NeMo object (and in particular it doesn't need a name). + """ + + power: float = 1.0 + cycle: bool = False + + +""" +Pytorch Optimizers +""" + + +@dataclass +class StepLRParams(SchedulerParams): + """ + Config for StepLR. + It is not derived from Config as it is not a NeMo object (and in particular it doesn't need a name). + """ + + step_size: float = 0.1 + gamma: float = 0.1 + + +@dataclass +class ExponentialLRParams(SchedulerParams): + """ + Config for ExponentialLR. + It is not derived from Config as it is not a NeMo object (and in particular it doesn't need a name). + """ + + gamma: float = 0.9 + + +@dataclass +class ReduceLROnPlateauParams: + """ + Config for ReduceLROnPlateau. + It is not derived from Config as it is not a NeMo object (and in particular it doesn't need a name). + """ + + mode: str = 'min' + factor: float = 0.1 + patience: int = 10 + verbose: bool = False + threshold: float = 1e-4 + threshold_mode: str = 'rel' + cooldown: int = 0 + min_lr: float = 0 + eps: float = 1e-8 + + +@dataclass +class CyclicLRParams(SchedulerParams): + """ + Config for CyclicLR. + NOTE: + # `scale_fn` is not supported + + It is not derived from Config as it is not a NeMo object (and in particular it doesn't need a name). + """ + + base_lr: float = 0.001 + max_lr: float = 0.1 + step_size_up: int = 2000 + step_size_down: Optional[int] = None + mode: str = 'triangular' + gamma: float = 1.0 + scale_mode: str = 'cycle' + # scale_fn is not supported + cycle_momentum: bool = True + base_momentum: float = 0.8 + max_momentum: float = 0.9 + + +def register_scheduler_params(name: str, scheduler_params: SchedulerParams): + """ + Checks if the schduler config name exists in the registry, and if it doesnt, adds it. + + This allows custom schedulers to be added and called by name during instantiation. + + Args: + name: Name of the optimizer. Will be used as key to retrieve the optimizer. + scheduler_params: SchedulerParams class + """ + if name in AVAILABLE_SCHEDULER_PARAMS: + raise ValueError(f"Cannot override pre-existing optimizers. Conflicting optimizer name = {name}") + + AVAILABLE_SCHEDULER_PARAMS[name] = scheduler_params + + +def get_scheduler_config(name: str, **kwargs: Optional[Dict[str, Any]]) -> SchedulerParams: + """ + Convenience method to obtain a SchedulerParams class and partially instantiate it with optimizer kwargs. + + Args: + name: Name of the SchedulerParams in the registry. + kwargs: Optional kwargs of the optimizer used during instantiation. + + Returns: + a partially instantiated SchedulerParams + """ + if name not in AVAILABLE_SCHEDULER_PARAMS: + raise ValueError( + f"Cannot resolve scheduler parameters '{name}'. Available scheduler parameters are : " + f"{AVAILABLE_SCHEDULER_PARAMS.keys()}" + ) + + scheduler_params = AVAILABLE_SCHEDULER_PARAMS[name] + scheduler_params = partial(scheduler_params, **kwargs) + return scheduler_params + + +AVAILABLE_SCHEDULER_PARAMS = { + 'SchedulerParams': SchedulerParams, + 'WarmupPolicyParams': WarmupSchedulerParams, + 'WarmupHoldPolicyParams': WarmupHoldSchedulerParams, + 'SquareAnnealingParams': SquareAnnealingParams, + 'SquareRootAnnealingParams': SquareRootAnnealingParams, + 'InverseSquareRootAnnealingParams': InverseSquareRootAnnealingParams, + 'CosineAnnealingParams': CosineAnnealingParams, + 'NoamAnnealingParams': NoamAnnealingParams, + 'WarmupAnnealingParams': WarmupAnnealingParams, + 'PolynomialDecayAnnealingParams': PolynomialDecayAnnealingParams, + 'PolynomialHoldDecayAnnealingParams': PolynomialHoldDecayAnnealingParams, +} diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/config/set_config.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/config/set_config.py new file mode 100644 index 0000000000000000000000000000000000000000..adc8f9c1cc2085c0bf465f253675f85f9094c434 --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/config/set_config.py @@ -0,0 +1,123 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools +from typing import Any, Callable, Optional + +from hydra._internal.utils import _run_hydra, get_args_parser +from hydra.core.config_store import ConfigStore +from hydra.types import TaskFunction +from omegaconf import DictConfig + +from nemo.core.config import Config + + +def hydra_runner( + config_path: Optional[str] = None, config_name: Optional[str] = None, schema: Optional[Any] = None +) -> Callable[[TaskFunction], Any]: + """ + Decorator used for passing the Config paths to main function. + Optionally registers a schema used for validation/providing default values. + + Args: + config_path: Path to the directory where the config exists. + config_name: Name of the config file. + schema: Structured config type representing the schema used for validation/providing default values. + """ + if schema is not None: + # Create config store. + cs = ConfigStore.instance() + # Register the configuration as a node under a given name. + cs.store(name=config_name.replace(".yaml", ""), node=schema) + + def decorator(task_function: TaskFunction) -> Callable[[], None]: + @functools.wraps(task_function) + def wrapper(cfg_passthrough: Optional[DictConfig] = None) -> Any: + # Check it config was passed. + if cfg_passthrough is not None: + return task_function(cfg_passthrough) + else: + args = get_args_parser() + + # Parse arguments in order to retrieve overrides + parsed_args = args.parse_args() + + # Get overriding args in dot string format + overrides = parsed_args.overrides # type: list + + # Update overrides + overrides.append("hydra.run.dir=.") + overrides.append('hydra.job_logging.root.handlers=null') + + # Wrap a callable object with name `parse_args` + # This is to mimic the ArgParser.parse_args() API. + class _argparse_wrapper: + def __init__(self, arg_parser): + self.arg_parser = arg_parser + self._actions = arg_parser._actions + + def parse_args(self, args=None, namespace=None): + return parsed_args + + # no return value from run_hydra() as it may sometime actually run the task_function + # multiple times (--multirun) + _run_hydra( + args_parser=_argparse_wrapper(args), + task_function=task_function, + config_path=config_path, + config_name=config_name, + strict=None, + ) + + return wrapper + + return decorator + + +def set_config(config: Config) -> Callable[[TaskFunction], Any]: + """ + Decorator used for passing the Structured Configs to main function. + + Args: + config: config class derived from Config. + """ + # Get class name. Not sure how important this is, but coudn't get name by accessing type().__name__. + class_name = str(config) + # Create config store. + cs = ConfigStore.instance() + # Register the configuration as a node under a given name. + cs.store(name=class_name, node=config) + + def decorator(task_function: TaskFunction) -> Callable[[], None]: + @functools.wraps(task_function) + def wrapper(cfg_passthrough: Optional[DictConfig] = None) -> Any: + # Check it config was passed. + if cfg_passthrough is not None: + return task_function(cfg_passthrough) + else: + args = get_args_parser() + + # no return value from run_hydra() as it may sometime actually run the task_function + # multiple times (--multirun) + _run_hydra( + args_parser=args, + task_function=task_function, + config_path=None, + config_name=class_name, + strict=None, + ) + + return wrapper + + return decorator diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/neural_types/__init__.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/neural_types/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8598bbe67f232e73be763ac0a5a37c7f7e6a5260 --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/neural_types/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from nemo.core.neural_types.axes import * +from nemo.core.neural_types.comparison import * +from nemo.core.neural_types.elements import * +from nemo.core.neural_types.neural_type import * diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/neural_types/axes.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/neural_types/axes.py new file mode 100644 index 0000000000000000000000000000000000000000..cda781f021cc544572eec543d2828fd38c9bf0a6 --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/neural_types/axes.py @@ -0,0 +1,101 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from enum import Enum +from typing import Optional + +__all__ = ['AxisKindAbstract', 'AxisKind', 'AxisType'] + + +class AxisKindAbstract(Enum): + """This is an abstract Enum to represents what does varying axis dimension mean. + In practice, you will almost always use AxisKind Enum. This Enum should be inherited by + your OWN Enum if you aren't satisfied with AxisKind. Then your own Enum can be used + instead of AxisKind.""" + + pass + + +class AxisKind(AxisKindAbstract): + """This Enum represents what does varying axis dimension mean. + For example, does this dimension correspond to width, batch, time, etc. + The "Dimension" and "Channel" kinds are the same and used to represent + a general axis. "Any" axis will accept any axis kind fed to it. + """ + + Batch = 0 + Time = 1 + Dimension = 2 + Channel = 2 + Width = 3 + Height = 4 + Any = 5 + Sequence = 6 + FlowGroup = 7 + Singleton = 8 # Used to represent a axis that has size 1 + + def __repr__(self): + return self.__str__() + + def __str__(self): + return str(self.name).lower() + + @staticmethod + def from_str(label): + """Returns AxisKind instance based on short string representation""" + _label = label.lower().strip() + if _label == "b" or _label == "n" or _label == "batch": + return AxisKind.Batch + elif _label == "t" or _label == "time": + return AxisKind.Time + elif _label == "d" or _label == "c" or _label == "channel": + return AxisKind.Dimension + elif _label == "w" or _label == "width": + return AxisKind.Width + elif _label == "h" or _label == "height": + return AxisKind.Height + elif _label == "s" or _label == "singleton": + return AxisKind.Singleton + elif _label == "flowgroup": + return AxisKind.FlowGroup + elif _label == "any": + return AxisKind.Any + else: + raise ValueError(f"Can't create AxisKind from {label}") + + +class AxisType(object): + """This class represents axis semantics and (optionally) it's dimensionality + Args: + kind (AxisKindAbstract): what kind of axis it is? For example Batch, Height, etc. + size (int, optional): specify if the axis should have a fixed size. By default it is set to None and you + typically do not want to set it for Batch and Time + is_list (bool, default=False): whether this is a list or a tensor axis + """ + + def __init__(self, kind: AxisKindAbstract, size: Optional[int] = None, is_list=False): + if size is not None and is_list: + raise ValueError("The axis can't be list and have a fixed size") + self.kind = kind + self.size = size + self.is_list = is_list + + def __repr__(self): + if self.size is None: + representation = str(self.kind) + else: + representation = f"{str(self.kind)}:{self.size}" + if self.is_list: + representation += "_listdim" + return representation diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/neural_types/comparison.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/neural_types/comparison.py new file mode 100644 index 0000000000000000000000000000000000000000..391096221d478fd9717c019c3faf0d52ab6169da --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/neural_types/comparison.py @@ -0,0 +1,32 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from enum import Enum + +__all__ = ['NeuralTypeComparisonResult'] + + +class NeuralTypeComparisonResult(Enum): + """The result of comparing two neural type objects for compatibility. + When comparing A.compare_to(B):""" + + SAME = 0 + LESS = 1 # A is B + GREATER = 2 # B is A + DIM_INCOMPATIBLE = 3 # Resize connector might fix incompatibility + TRANSPOSE_SAME = 4 # A transpose and/or converting between lists and tensors will make them same + CONTAINER_SIZE_MISMATCH = 5 # A and B contain different number of elements + INCOMPATIBLE = 6 # A and B are incompatible + SAME_TYPE_INCOMPATIBLE_PARAMS = 7 # A and B are of the same type but parametrized differently + UNCHECKED = 8 # type comparison wasn't done diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/neural_types/elements.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/neural_types/elements.py new file mode 100644 index 0000000000000000000000000000000000000000..8e94783851577e818ea1704502e23b3926b5afb0 --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/neural_types/elements.py @@ -0,0 +1,320 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import abc +from abc import ABC +from typing import Dict, Optional, Tuple + +from nemo.core.neural_types.comparison import NeuralTypeComparisonResult + +__all__ = [ + 'ElementType', + 'VoidType', + 'ChannelType', + 'AcousticEncodedRepresentation', + 'AudioSignal', + 'SpectrogramType', + 'MelSpectrogramType', + 'MFCCSpectrogramType', + 'LogitsType', + 'LabelsType', + 'HypothesisType', + 'LossType', + 'RegressionValuesType', + 'CategoricalValuesType', + 'PredictionsType', + 'LogprobsType', + 'LengthsType', + 'EmbeddedTextType', + 'EncodedRepresentation', + 'MaskType', + 'Target', + 'ClassificationTarget', + 'ImageFeatureValue', + 'Index', + 'ImageValue', + 'NormalizedImageValue', + 'StringLabel', + 'StringType', + 'TokenIndex', + 'Length', + 'IntType', + 'FloatType', + 'NormalDistributionSamplesType', + 'NormalDistributionMeanType', + 'NormalDistributionLogVarianceType', + 'TokenDurationType', + 'TokenLogDurationType', + 'LogDeterminantType', + 'SequenceToSequenceAlignmentType', +] + + +class ElementType(ABC): + """Abstract class defining semantics of the tensor elements. + We are relying on Python for inheritance checking""" + + def __str__(self): + return self.__doc__ + + def __repr__(self): + return self.__class__.__name__ + + @property + def type_parameters(self) -> Dict: + """Override this property to parametrize your type. For example, you can specify 'storage' type such as + float, int, bool with 'dtype' keyword. Another example, is if you want to represent a signal with a + particular property (say, sample frequency), then you can put sample_freq->value in there. + When two types are compared their type_parameters must match.""" + return {} + + @property + def fields(self) -> Optional[Tuple]: + """This should be used to logically represent tuples/structures. For example, if you want to represent a + bounding box (x, y, width, height) you can put a tuple with names ('x', y', 'w', 'h') in here. + Under the hood this should be converted to the last tesnor dimension of fixed size = len(fields). + When two types are compared their fields must match.""" + return None + + def compare(self, second) -> NeuralTypeComparisonResult: + # First, check general compatibility + first_t = type(self) + second_t = type(second) + + if first_t == second_t: + result = NeuralTypeComparisonResult.SAME + elif issubclass(first_t, second_t): + result = NeuralTypeComparisonResult.LESS + elif issubclass(second_t, first_t): + result = NeuralTypeComparisonResult.GREATER + else: + result = NeuralTypeComparisonResult.INCOMPATIBLE + + if result != NeuralTypeComparisonResult.SAME: + return result + else: + # now check that all parameters match + check_params = set(self.type_parameters.keys()) == set(second.type_parameters.keys()) + if check_params is False: + return NeuralTypeComparisonResult.SAME_TYPE_INCOMPATIBLE_PARAMS + else: + for k1, v1 in self.type_parameters.items(): + if v1 is None or second.type_parameters[k1] is None: + # Treat None as Void + continue + if v1 != second.type_parameters[k1]: + return NeuralTypeComparisonResult.SAME_TYPE_INCOMPATIBLE_PARAMS + # check that all fields match + if self.fields == second.fields: + return NeuralTypeComparisonResult.SAME + else: + return NeuralTypeComparisonResult.INCOMPATIBLE + + +class VoidType(ElementType): + """Void-like type which is compatible with everything. + It is a good practice to use this type only as necessary. + For example, when you need template-like functionality. + """ + + def compare(cls, second: abc.ABCMeta) -> NeuralTypeComparisonResult: + return NeuralTypeComparisonResult.SAME + + +# TODO: Consider moving these files elsewhere +class ChannelType(ElementType): + """Element to represent convolutional input/output channel. + """ + + +class EmbeddedTextType(ChannelType): + """Element to represent output on word/text embedding layers + """ + + +class LogitsType(ElementType): + """Element type to represent logits""" + + +class LogprobsType(ElementType): + """Element type to represent log-probabilities. For example, outputs of softmax layers.""" + + +class LabelsType(ElementType): + """Element type to represent some sort of labels. This is often used as a base class to create + a more concrete types such as RegressionValuesType, etc.""" + + +class HypothesisType(LabelsType): + """Element type to represent some decoded hypothesis, which may further be processed to obtain + a concrete label.""" + + +class LengthsType(ElementType): + """Element type representing lengths of something""" + + +class LossType(ElementType): + """Element type to represent outputs of Loss modules""" + + +class EncodedRepresentation(ChannelType): + """Element type to represent encoded representation, for example, encoder's output""" + + +class AcousticEncodedRepresentation(EncodedRepresentation): + """Element type to represent encoded representation returned by the acoustic encoder model""" + + +class AudioSignal(ElementType): + """Element type to represent encoded representation returned by the acoustic encoder model + Args: + freq (int): sampling frequency of a signal. Note that two signals will only be the same if their + freq is the same. + """ + + def __init__(self, freq: int = None): + self._params = {} + self._params['freq'] = freq + + @property + def type_parameters(self): + return self._params + + +class SpectrogramType(ChannelType): + """Element type to represent generic spectrogram signal""" + + +class MelSpectrogramType(SpectrogramType): + """Element type to represent mel spectrogram signal""" + + +class MFCCSpectrogramType(SpectrogramType): + """Element type to represent MFCC spectrogram signal""" + + +class PredictionsType(LabelsType): + """Element type to represent some sort of predictions returned by model""" + + +class RegressionValuesType(PredictionsType): + """Element type to represent labels for regression task""" + + +class CategoricalValuesType(PredictionsType): + """Element type to represent labels for categorical classification task""" + + +class MaskType(PredictionsType): + """Element type to represent a boolean mask""" + + +class Index(ElementType): + """Type representing an element being an index of the sample.""" + + +class Target(ElementType): + """ + Type representing an element being a target value. + """ + + +class ClassificationTarget(Target): + """ + Type representing an element being target value in the classification task, i.e. identifier of a desired class. + """ + + +class ImageValue(ElementType): + """ + Type representing an element/value of a single image channel, + e.g. a single element (R) of RGB image. + """ + + +class NormalizedImageValue(ImageValue): + """ + Type representing an element/value of a single image channel normalized to <0-1> range, + e.g. a single element (R) of normalized RGB image. + """ + + +class ImageFeatureValue(ImageValue): + """Type representing an element (single value) of a (image) feature maps.""" + + +class StringType(ElementType): + """Element type representing a single string""" + + +class StringLabel(StringType): + """ + Type representing an label being a string with class name (e.g. the "hamster" class in CIFAR100). + """ + + +class BoolType(ElementType): + """Element type representing a single integer""" + + +class IntType(ElementType): + """Element type representing a single integer""" + + +class FloatType(ElementType): + """Element type representing a single float""" + + +class TokenIndex(IntType): + """Type representing an element being index of a token in some kind of a vocabulary.""" + + +class Length(IntType): + """Type representing an element storing a "length" (e.g. length of a list).""" + + +class ProbabilityDistributionSamplesType(ElementType): + """Element to represent tensors that meant to be sampled from a valid probability distribution + """ + + +class NormalDistributionSamplesType(ProbabilityDistributionSamplesType): + """Element to represent tensors that meant to be sampled from a valid normal distribution + """ + + +class SequenceToSequenceAlignmentType(ElementType): + """Class to represent the alignment from seq-to-seq attention outputs. Generally a mapping from endcoder time steps + to decoder time steps.""" + + +class NormalDistributionMeanType(ElementType): + """Element to represent the mean of a normal distribution""" + + +class NormalDistributionLogVarianceType(ElementType): + """Element to represent the log variance of a normal distribution""" + + +class TokenDurationType(ElementType): + """Element for representing the duration of a token""" + + +class TokenLogDurationType(ElementType): + """Element for representing the log-duration of a token""" + + +class LogDeterminantType(ElementType): + """Element for representing log determinants usually used in flow models""" diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/neural_types/neural_type.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/neural_types/neural_type.py new file mode 100644 index 0000000000000000000000000000000000000000..8714d700b08192c57955425e601d4a029bf6e17c --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/neural_types/neural_type.py @@ -0,0 +1,223 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Tuple + +from nemo.core.neural_types.axes import AxisKind, AxisType +from nemo.core.neural_types.comparison import NeuralTypeComparisonResult +from nemo.core.neural_types.elements import ElementType, VoidType + +__all__ = [ + 'NeuralType', + 'NeuralTypeError', + 'NeuralPortNameMismatchError', + 'NeuralPortNmTensorMismatchError', +] + + +class NeuralType(object): + """This is the main class which would represent neural type concept. + It is used to represent *the types* of inputs and outputs. + Args: + axes (Optional[Tuple]): a tuple of AxisTypes objects representing the semantics of what varying each axis means + You can use a short, string-based form here. For example: ('B', 'C', 'H', 'W') would correspond to an NCHW + format frequently used in computer vision. ('B', 'T', 'D') is frequently used for signal processing and + means [batch, time, dimension/channel]. + elements_type (ElementType): an instance of ElementType class representing the semantics of what is stored + inside the tensor. For example: logits (LogitsType), log probabilities (LogprobType), etc. + optional (bool): By default, this is false. If set to True, it would means that input to the port of this + type can be optional. + """ + + def __str__(self): + + if self.axes is not None: + return f"axes: {self.axes}; elements_type: {self.elements_type.__class__.__name__}" + else: + return f"axes: None; elements_type: {self.elements_type.__class__.__name__}" + + def __init__(self, axes: Optional[Tuple] = None, elements_type: ElementType = VoidType(), optional=False): + if not isinstance(elements_type, ElementType): + raise ValueError( + "elements_type of NeuralType must be an instance of a class derived from ElementType. " + "Did you pass a class instead?" + ) + self.elements_type = elements_type + if axes is not None: + NeuralType.__check_sanity(axes) + axes_list = [] + for axis in axes: + if isinstance(axis, str): + axes_list.append(AxisType(AxisKind.from_str(axis), None)) + elif isinstance(axis, AxisType): + axes_list.append(axis) + else: + raise ValueError("axis type must be either str or AxisType instance") + self.axes = tuple(axes_list) + else: + self.axes = None + self.optional = optional + + def compare(self, second) -> NeuralTypeComparisonResult: + """Performs neural type comparison of self with second. When you chain two modules' inputs/outputs via + __call__ method, this comparison will be called to ensure neural type compatibility.""" + # First, handle dimensionality + axes_a = self.axes + axes_b = second.axes + + # "Big void" type + if isinstance(self.elements_type, VoidType) and self.axes is None: + return NeuralTypeComparisonResult.SAME + + if self.axes is None: + if second.axes is None: + return self.elements_type.compare(second.elements_type) + else: + return NeuralTypeComparisonResult.INCOMPATIBLE + + dimensions_pass = NeuralType.__compare_axes(axes_a, axes_b) + element_comparison_result = self.elements_type.compare(second.elements_type) + + # SAME DIMS + if dimensions_pass == 0: + return element_comparison_result + # TRANSPOSE_SAME DIMS + elif dimensions_pass == 1: + if element_comparison_result == NeuralTypeComparisonResult.SAME: + return NeuralTypeComparisonResult.TRANSPOSE_SAME + else: + return NeuralTypeComparisonResult.INCOMPATIBLE + # DIM_INCOMPATIBLE DIMS + elif dimensions_pass == 2: + if element_comparison_result == NeuralTypeComparisonResult.SAME: + return NeuralTypeComparisonResult.DIM_INCOMPATIBLE + else: + return NeuralTypeComparisonResult.INCOMPATIBLE + else: + return NeuralTypeComparisonResult.INCOMPATIBLE + + def compare_and_raise_error(self, parent_type_name, port_name, second_object): + """ Method compares definition of one type with another and raises an error if not compatible. """ + type_comatibility = self.compare(second_object) + if ( + type_comatibility != NeuralTypeComparisonResult.SAME + and type_comatibility != NeuralTypeComparisonResult.GREATER + ): + raise NeuralPortNmTensorMismatchError( + parent_type_name, port_name, str(self), str(second_object.ntype), type_comatibility + ) + + def __eq__(self, other): + if isinstance(other, NeuralType): + return self.compare(other) + + return False + + @staticmethod + def __check_sanity(axes): + # check that list come before any tensor dimension + are_strings = True + for axis in axes: + if not isinstance(axis, str): + are_strings = False + if isinstance(axis, str) and not are_strings: + raise ValueError("Either use full class names or all strings") + if are_strings: + return + checks_passed = True + saw_tensor_dim = False + for axis in axes: + if not axis.is_list: + saw_tensor_dim = True + else: # current axis is a list + if saw_tensor_dim: # which is preceded by tensor dim + checks_passed = False + if not checks_passed: + raise ValueError( + "You have list dimension after Tensor dimension. All list dimensions must preceed Tensor dimensions" + ) + + @staticmethod + def __compare_axes(axes_a, axes_b) -> int: + """ + Compares axes_a and axes_b + Args: + axes_a: first axes tuple + axes_b: second axes tuple + + Returns: + 0 - if they are exactly the same + 1 - if they are "TRANSPOSE_SAME" + 2 - if the are "DIM_INCOMPATIBLE" + 3 - if they are different + """ + if axes_a is None and axes_b is None: + return 0 + elif axes_a is None and axes_b is not None: + return 3 + elif axes_a is not None and axes_b is None: + return 3 + elif len(axes_a) != len(axes_b): + return 3 + # After these ifs we know that len(axes_a) == len(axes_b) + + same = True + kinds_a = dict() + kinds_b = dict() + for axis_a, axis_b in zip(axes_a, axes_b): + kinds_a[axis_a.kind] = axis_a.size + kinds_b[axis_b.kind] = axis_b.size + if axis_a.kind == AxisKind.Any: + same = True + elif ( + axis_a.kind != axis_b.kind + or axis_a.is_list != axis_b.is_list + or (axis_a.size != axis_b.size and axis_a.size is not None) + ): + same = False + if same: + return 0 + else: + # can be TRANSPOSE_SAME, DIM_INCOMPATIBLE + if kinds_a.keys() == kinds_b.keys(): + for key, value in kinds_a.items(): + if kinds_b[key] != value: + return 2 + return 1 + else: + return 3 + + +class NeuralTypeError(Exception): + """Base class for neural type related exceptions.""" + + +class NeuralPortNameMismatchError(NeuralTypeError): + """Exception raised when neural module is called with incorrect port + names.""" + + def __init__(self, input_port_name): + super().__init__() + self.message = "Wrong input port name: {0}".format(input_port_name) + + +class NeuralPortNmTensorMismatchError(NeuralTypeError): + """Exception raised when a port is fed with a NmTensor of incompatible + type.""" + + def __init__(self, class_name, port_name, first_type, second_type, type_comatibility): + super().__init__() + self.message = "\nIn {}. \nPort: {} and a NmTensor it was fed are \n".format(class_name, port_name) + self.message += "of incompatible neural types:\n\n{} \n\n and \n\n{}".format(first_type, second_type) + self.message += "\n\nType comparison result: {}".format(type_comatibility) diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/optim/__init__.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/optim/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b421582c00347cb61f4d502886942a675d8a944c --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/optim/__init__.py @@ -0,0 +1,29 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nemo.core.optim.lr_scheduler import ( + CosineAnnealing, + InverseSquareRootAnnealing, + NoamAnnealing, + PolynomialDecayAnnealing, + PolynomialHoldDecayAnnealing, + SquareAnnealing, + SquareRootAnnealing, + WarmupAnnealing, + WarmupHoldPolicy, + WarmupPolicy, + prepare_lr_scheduler, +) +from nemo.core.optim.novograd import Novograd +from nemo.core.optim.optimizers import get_optimizer, parse_optimizer_args, register_optimizer diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/optim/lr_scheduler.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/optim/lr_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..87df3ea1ebc16be650be7821fb1dcc0da7603b70 --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/optim/lr_scheduler.py @@ -0,0 +1,688 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import dataclasses +import math +import warnings +from functools import partial +from typing import Any, Dict, Optional, Union + +import hydra +import torch.optim as optim +import torch.optim.lr_scheduler as pt_scheduler +import torch.utils.data.dataloader as dataloader +from omegaconf import DictConfig, OmegaConf +from torch.optim.lr_scheduler import _LRScheduler + +from nemo.core.config import SchedulerParams, get_scheduler_config, register_scheduler_params +from nemo.utils import logging + + +class WarmupPolicy(_LRScheduler): + """Adds warmup kwargs and warmup logic to lr policy. + All arguments should be passed as kwargs for clarity, + Args: + warmup_steps: Number of training steps in warmup stage + warmup_ratio: Ratio of warmup steps to total steps + max_steps: Total number of steps while training or `None` for + infinite training + """ + + def __init__(self, optimizer, *, warmup_steps=None, warmup_ratio=None, warmup_power=None, max_steps=None, min_lr=0.0, last_epoch=-1): + assert not ( + warmup_steps is not None and warmup_ratio is not None + ), "Either use particular number of step or ratio" + assert warmup_ratio is None or max_steps is not None, "If there is a ratio, there should be a total steps" + + # It is necessary to assign all attributes *before* __init__, + # as class is wrapped by an inner class. + self.max_steps = max_steps + if warmup_steps is not None: + self.warmup_steps = warmup_steps + elif warmup_ratio is not None: + self.warmup_steps = int(warmup_ratio * max_steps) + else: + self.warmup_steps = 0 + + self.warmup_power = warmup_power + + self.min_lr = min_lr + super().__init__(optimizer, last_epoch) + + def get_lr(self): + if not self._get_lr_called_within_step: + warnings.warn( + "To get the last learning rate computed by the scheduler, please use `get_last_lr()`.", UserWarning + ) + + step = self.last_epoch + + if step <= self.warmup_steps: + lr_val = (step + 1) / (self.warmup_steps + 1) + if self.warmup_power: + lr_val = lr_val ** self.warmup_power + + return [initial_lr * lr_val for initial_lr in self.base_lrs] + + if step > self.max_steps: + return [self.min_lr for _ in self.base_lrs] + + return self._get_lr(step) + + def _get_lr(self, step): + """Simple const lr policy""" + return self.base_lrs + + +class WarmupHoldPolicy(WarmupPolicy): + """Variant of WarmupPolicy which maintains high learning rate for a defined number of steps. + All arguments should be passed as kwargs for clarity, + Args: + warmup_steps: Number of training steps in warmup stage + warmup_ratio: Ratio of warmup steps to total steps + hold_steps: Number of training steps to hold the learning rate after warm up + hold_ratio: Ratio of hold steps to total steps + max_steps: Total number of steps while training or `None` for + infinite training + """ + + def __init__( + self, + optimizer, + *, + warmup_steps=None, + warmup_ratio=None, + warmup_power=None, + hold_steps=None, + hold_ratio=None, + max_steps=None, + min_lr=0.0, + last_epoch=-1, + ): + assert not (hold_steps is not None and hold_ratio is not None), "Either use particular number of step or ratio" + assert hold_ratio is None or max_steps is not None, "If there is a ratio, there should be a total steps" + + self.min_lr = min_lr + self._last_warmup_lr = 0.0 + + # Necessary to duplicate as class attributes are hidden in inner class + self.max_steps = max_steps + if warmup_steps is not None: + self.warmup_steps = warmup_steps + elif warmup_ratio is not None: + self.warmup_steps = int(warmup_ratio * max_steps) + else: + self.warmup_steps = 0 + + if hold_steps is not None: + self.hold_steps = hold_steps + self.warmup_steps + elif hold_ratio is not None: + self.hold_steps = int(hold_ratio * max_steps) + self.warmup_steps + else: + self.hold_steps = 0 + + super().__init__( + optimizer, + warmup_steps=warmup_steps, + warmup_ratio=warmup_ratio, + warmup_power=warmup_power, + max_steps=max_steps, + last_epoch=last_epoch, + min_lr=min_lr, + ) + + def get_lr(self): + if not self._get_lr_called_within_step: + warnings.warn( + "To get the last learning rate computed by the scheduler, " "please use `get_last_lr()`.", UserWarning + ) + + step = self.last_epoch + + # Warmup phase + if step <= self.warmup_steps: + lr_val = (step + 1) / (self.warmup_steps + 1) + return [initial_lr * lr_val for initial_lr in self.base_lrs] + + # Hold phase + if (step >= self.warmup_steps) and (step < self.hold_steps): + return self.base_lrs + + if step > self.max_steps: + return [self.min_lr for _ in self.base_lrs] + + return self._get_lr(step) + + +def _squareroot_annealing(initial_lr, step, max_steps, min_lr): + mult = ((max_steps - step) / max_steps) ** 0.5 + out_lr = initial_lr * mult + out_lr = max(out_lr, min_lr) + return out_lr + + +def _square_annealing(initial_lr, step, max_steps, min_lr): + mult = ((max_steps - step) / max_steps) ** 2 + out_lr = initial_lr * mult + out_lr = max(out_lr, min_lr) + return out_lr + + +def _cosine_annealing(initial_lr, step, max_steps, min_lr): + mult = 0.5 * (1 + math.cos(math.pi * step / max_steps)) + out_lr = (initial_lr - min_lr) * mult + min_lr + return out_lr + + +def _poly_decay(initial_lr, step, decay_steps, power, min_lr, cycle): + if cycle: + multiplier = 1.0 if step == 0 else math.ceil(step / decay_steps) + decay_steps *= multiplier + else: + step = min(step, decay_steps) + p = step / decay_steps + lr = (initial_lr - min_lr) * math.pow(1.0 - p, power) + lr += min_lr + return lr + + +class SquareAnnealing(WarmupPolicy): + def __init__(self, optimizer, *, max_steps, min_lr=1e-5, last_epoch=-1, **kwargs): + super().__init__(optimizer=optimizer, max_steps=max_steps, last_epoch=last_epoch, min_lr=min_lr, **kwargs) + + def _get_lr(self, step): + new_lrs = [ + _square_annealing( + initial_lr=initial_lr, + step=step - self.warmup_steps, + max_steps=self.max_steps - self.warmup_steps, + min_lr=self.min_lr, + ) + for initial_lr in self.base_lrs + ] + return new_lrs + + +class SquareRootAnnealing(WarmupPolicy): + def __init__(self, optimizer, *, max_steps, min_lr=0, last_epoch=-1, **kwargs): + super().__init__(optimizer=optimizer, max_steps=max_steps, last_epoch=last_epoch, min_lr=min_lr, **kwargs) + + def _get_lr(self, step): + new_lrs = [ + _squareroot_annealing(initial_lr=initial_lr, step=step, max_steps=self.max_steps, min_lr=self.min_lr) + for initial_lr in self.base_lrs + ] + return new_lrs + + +class CosineAnnealing(WarmupPolicy): + def __init__(self, optimizer, *, max_steps, min_lr=0, last_epoch=-1, **kwargs): + super().__init__(optimizer=optimizer, max_steps=max_steps, last_epoch=last_epoch, min_lr=min_lr, **kwargs) + + def _get_lr(self, step): + for initial_lr in self.base_lrs: + if initial_lr < self.min_lr: + raise ValueError( + f"{self} received an initial learning rate that was lower than the minimum learning rate." + ) + + new_lrs = [ + _cosine_annealing( + initial_lr=initial_lr, + step=step - self.warmup_steps, + max_steps=self.max_steps - self.warmup_steps, + min_lr=self.min_lr, + ) + for initial_lr in self.base_lrs + ] + return new_lrs + + +class NoamAnnealing(_LRScheduler): + def __init__( + self, optimizer, *, d_model, warmup_steps=None, warmup_ratio=None, max_steps=None, min_lr=0.0, last_epoch=-1 + ): + self._normalize = d_model ** (-0.5) + assert not ( + warmup_steps is not None and warmup_ratio is not None + ), "Either use particular number of step or ratio" + assert warmup_ratio is None or max_steps is not None, "If there is a ratio, there should be a total steps" + + # It is necessary to assign all attributes *before* __init__, + # as class is wrapped by an inner class. + self.max_steps = max_steps + if warmup_steps is not None: + self.warmup_steps = warmup_steps + elif warmup_ratio is not None: + self.warmup_steps = int(warmup_ratio * max_steps) + else: + self.warmup_steps = 0 + + self.min_lr = min_lr + super().__init__(optimizer, last_epoch) + + def get_lr(self): + if not self._get_lr_called_within_step: + warnings.warn( + "To get the last learning rate computed by the scheduler, please use `get_last_lr()`.", UserWarning + ) + + step = max(1, self.last_epoch) + + if step > self.max_steps: + return [self.min_lr for _ in self.base_lrs] + + for initial_lr in self.base_lrs: + if initial_lr < self.min_lr: + raise ValueError( + f"{self} received an initial learning rate that was lower than the minimum learning rate." + ) + + new_lrs = [self._noam_annealing(initial_lr=initial_lr, step=step) for initial_lr in self.base_lrs] + return new_lrs + + def _noam_annealing(self, initial_lr, step): + mult = self._normalize * min(step ** (-0.5), step * (self.warmup_steps ** (-1.5))) + out_lr = initial_lr * mult + if step > self.warmup_steps: + out_lr = max(out_lr, self.min_lr) + return out_lr + + +class WarmupAnnealing(WarmupPolicy): + def __init__(self, optimizer, *, max_steps, last_epoch=-1, min_lr=0.0, **kwargs): + super().__init__(optimizer=optimizer, max_steps=max_steps, last_epoch=last_epoch, min_lr=min_lr, **kwargs) + + def _get_lr(self, step): + progress = float(step / self.max_steps) + warmup_ratio = float(self.warmup_steps / self.max_steps) + + mult = max((progress - 1.0) / (warmup_ratio - 1.0), 0.0) + out_lr = [initial_lr * mult for initial_lr in self.base_lrs] + + return out_lr + + +class InverseSquareRootAnnealing(WarmupPolicy): + def __init__(self, optimizer, *, max_steps, last_epoch=-1, min_lr=0.0, **kwargs): + super().__init__(optimizer=optimizer, max_steps=max_steps, **kwargs, last_epoch=last_epoch, min_lr=min_lr) + + def _get_lr(self, step): + denom = ((step + 1) / (self.warmup_steps + 1)) ** 0.5 + out_lr = [initial_lr / denom for initial_lr in self.base_lrs] + return out_lr + + +class PolynomialDecayAnnealing(WarmupPolicy): + def __init__(self, optimizer, *, max_steps, min_lr=0.0, power=1.0, cycle=False, last_epoch=-1, **kwargs): + self.power = power + self.cycle = cycle + + super().__init__(optimizer=optimizer, max_steps=max_steps, last_epoch=last_epoch, min_lr=min_lr, **kwargs) + + def _get_lr(self, step): + new_lrs = [ + _poly_decay( + initial_lr, + step=step - self.warmup_steps, + decay_steps=self.max_steps - self.warmup_steps, + power=self.power, + min_lr=self.min_lr, + cycle=self.cycle, + ) + for initial_lr in self.base_lrs + ] + return new_lrs + + +class PolynomialHoldDecayAnnealing(WarmupHoldPolicy): + def __init__(self, optimizer, *, max_steps, min_lr=0.0, power=1.0, cycle=False, last_epoch=-1, **kwargs): + self.power = power + self.cycle = cycle + + super().__init__(optimizer=optimizer, max_steps=max_steps, last_epoch=last_epoch, min_lr=min_lr, **kwargs) + + def _get_lr(self, step): + new_lrs = [ + _poly_decay( + initial_lr, + step=step - self.hold_steps, + decay_steps=self.max_steps - max(self.warmup_steps, self.hold_steps), + power=self.power, + min_lr=self.min_lr, + cycle=self.cycle, + ) + for initial_lr in self.base_lrs + ] + return new_lrs + + +def register_scheduler(name: str, scheduler: _LRScheduler, scheduler_params: SchedulerParams): + """ + Checks if the scheduler name exists in the registry, and if it doesnt, adds it. + + This allows custom schedulers to be added and called by name during instantiation. + + Args: + name: Name of the optimizer. Will be used as key to retrieve the optimizer. + scheduler: Scheduler class (inherits from _LRScheduler) + scheduler_params: The parameters as a dataclass of the scheduler + """ + if name in AVAILABLE_SCHEDULERS: + raise ValueError(f"Cannot override pre-existing schedulers. Conflicting scheduler name = {name}") + + AVAILABLE_SCHEDULERS[name] = scheduler + + sched_name = "{}_params".format(scheduler.__name__) + register_scheduler_params(name=sched_name, scheduler_params=scheduler_params) + + +def get_scheduler(name: str, **kwargs: Optional[Dict[str, Any]]) -> _LRScheduler: + """ + Convenience method to obtain an _LRScheduler class and partially instantiate it with optimizer kwargs. + + Args: + name: Name of the scheduler in the registry. + kwargs: Optional kwargs of the scheduler used during instantiation. + + Returns: + a partially instantiated _LRScheduler + """ + if name not in AVAILABLE_SCHEDULERS: + raise ValueError( + f"Cannot resolve scheduler{name}'. Available optimizers are : " f"{AVAILABLE_SCHEDULERS.keys()}" + ) + + scheduler_cls = AVAILABLE_SCHEDULERS[name] + scheduler = partial(scheduler_cls, **kwargs) + return scheduler + + +def prepare_lr_scheduler( + optimizer: optim.Optimizer, + scheduler_config: Union[Dict[str, Any], DictConfig], + train_dataloader: Optional[dataloader.DataLoader] = None, +) -> Optional[Dict[str, Any]]: + """ + Constructs an LR Scheduler (optionally) for a given optimizer, based on a config with the following schema + + optim: + name: + lr: + + # + args: + name: auto # special keyword, resolves to correct optimizer config for given optimizer name + # cls: nemo.core.config.optimizers.NovogradParams # explicit instantiation by class path + params: # optional override parameters for the optimizer config + betas: [0.8, 0.5] + weight_decay: 0.001 + + # scheduler setup + sched: + name: + iters_per_batch: null # computed at runtime; mandatory to have + max_steps: null # computed at runtime or explicitly set here; mandatory to have + + # pytorch lightning args + monitor: val_loss + reduce_on_plateau: false + + # + args: + name: auto # special keyword, resolves to correct optimizer config for given optimizer name + # cls: nemo.core.config.schedulers.CosineAnnealingParams # explicit instantiation by class path + params: # optional override parameters for the optimizer config + warmup_steps: null + warmup_ratio: null + min_lr: 0.0 + last_epoch: -1 + + Args: + optimizer: An instantiated Optimizer. + scheduler_config: A dictionary / config dict which follows the above schema. + train_dataloader: Optional requirement, must be passed if "iters_per_batch" is defined + instead of "max_steps". Used to compute effective "max_steps". + + Returns: + A dictionary containing the LR Scheduler implementation if the config was successfully parsed + along with other parameters required by Pytorch Lightning, otherwise None. + """ + # Build nested dictionary for convenience out of structured objects + if isinstance(scheduler_config, DictConfig): + scheduler_config = OmegaConf.to_container(scheduler_config, resolve=True) + + elif dataclasses.is_dataclass(scheduler_config): + # Recursively transform data classes to basic dictionaries + scheduler_config = OmegaConf.create(scheduler_config) + scheduler_config = OmegaConf.to_container(scheduler_config, resolve=True) + + # Test to see if config follows above schema + + add_max_args_flag = True + interval = 'step' + if scheduler_config is not None: + if 'args' in scheduler_config: + scheduler_args = scheduler_config.pop('args') + else: + scheduler_args = copy.deepcopy(scheduler_config) + + # Remove extra parameters from scheduler_args nest + # Assume all other parameters are to be passed into scheduler constructor + + if 'name' in scheduler_args and scheduler_args['name'] == 'ReduceLROnPlateau': + add_max_args_flag = False + interval = 'epoch' + + scheduler_args.pop('name', None) + scheduler_args.pop('t_max_epochs', None) + scheduler_args.pop('t_accumulate_grad_batches', None) + scheduler_args.pop('t_limit_train_batches', None) + scheduler_args.pop('t_num_workers', None) + scheduler_args.pop('monitor', None) + scheduler_args.pop('reduce_on_plateau', None) + + else: + # Return gracefully in case `sched` was not supplied; inform user + logging.info('Scheduler not initialized as no `sched` config supplied to setup_optimizer()') + return None + + # Try instantiation of scheduler params from config class path + if 'name' not in scheduler_config: + scheduler_args_cfg = OmegaConf.create(scheduler_args) + scheduler_conf = hydra.utils.instantiate(scheduler_args_cfg) + scheduler_args = vars(scheduler_conf) + + # Get name of the scheduler + scheduler_name = scheduler_conf.__class__.__name__ + + if 'Params' in scheduler_name: + scheduler_name = scheduler_name.replace('Params', '') + else: + # Class path instantiation failed; try resolving "name" component + + # Get name of the scheduler + if 'name' in scheduler_config: + scheduler_name = scheduler_config['name'] + else: + logging.warning( + "Could not resolve classpath for Scheduler Config, and `name` " + "was not provided either. \n" + "Scheduler cannot be instantiated !" + ) + return None + + # If class path was not provided, perhaps `name` is provided for resolution + if 'name' in scheduler_args: + # If `auto` is passed as name for resolution of optimizer name, + # then lookup optimizer name and resolve its parameter config + if scheduler_args['name'] == 'auto': + scheduler_params_name = "{}Params".format(scheduler_name) + else: + scheduler_params_name = scheduler_args['name'] + + # Get override arguments provided in the config yaml file / Dict Config + scheduler_params_override = scheduler_args.get('params', {}) + + # If params is itself a dict config object provided explicitly in Dict Config + # Resolve to dictionary for convenience + if isinstance(scheduler_params_override, DictConfig): + scheduler_params_override = OmegaConf.to_container(scheduler_params_override, resolve=True) + + # Get and instantiate the Config dataclass for this scheduler + scheduler_params_cls = get_scheduler_config(scheduler_params_name, **scheduler_params_override) + scheduler_params = scheduler_params_cls() # instantiate the parameters object + scheduler_args = vars(scheduler_params) # extract just the dictionary from the Config object + + else: + # assume the input dictionary is schedular args (from dataclasses / omegaconf) + pass + + # Extract value to monitor in losses, if provided. + if 'monitor' in scheduler_config: + monitor = scheduler_config.get('monitor') + else: + # Default to train loss + monitor = 'loss' + + # Store exact max_steps if it is provided + if 'max_steps' in scheduler_config and scheduler_config['max_steps'] is not None: + max_steps = scheduler_config['max_steps'] + + elif 't_max_epochs' in scheduler_config: + # Compute effective max_steps if t_max_epochs is provided + if train_dataloader is None: + logging.warning( + 'As `t_max_epochs` is provided/computed, it is required to pass the train dataloader in order\n' + 'to compute effective maximum number of steps.\n' + 'Scheduler will not be instantiated !' + ) + return None + + # Raise exception if neither `max_steps` nor `t_max_epochs` is provided + if scheduler_config.get('t_max_epochs', None) is None: + logging.warning( + "`t_max_epochs` cannot be None when `max_steps` is not not provided.\n" + "This can occur when `train dataloader` is not available to correctly " + "prepare the scheduler.\n" + "Scheduler will not be instantiated !" + ) + return None + + # Get iters_per_batch + max_epochs = scheduler_config.get('t_max_epochs') + accumulate_grad_batches = scheduler_config.get('t_accumulate_grad_batches') + limit_train_batches = scheduler_config.get('t_limit_train_batches') + num_workers = scheduler_config.get('t_num_workers') + + # Compute effective num max_steps + num_samples = len(train_dataloader.dataset) + batch_size = train_dataloader.batch_size + drop_last = train_dataloader.drop_last + + max_steps = compute_max_steps( + max_epochs=max_epochs, + accumulate_grad_batches=accumulate_grad_batches, + limit_train_batches=limit_train_batches, + num_workers=num_workers, + num_samples=num_samples, + batch_size=batch_size, + drop_last=drop_last, + ) + + else: + logging.warning( + "Neither `max_steps` nor `iters_per_batch` were provided to `optim.sched`, " + "cannot compute effective `max_steps` !\n" + "Scheduler will not be instantiated !" + ) + return None + + # Inject max_steps (effective or provided) into the scheduler config + if add_max_args_flag: + scheduler_args['max_steps'] = max_steps + + # Get the scheduler class from the config + scheduler_cls = get_scheduler(scheduler_name, **scheduler_args) + + # Instantiate the LR schedule + schedule = scheduler_cls(optimizer, **scheduler_args) + + logging.info( + 'Scheduler "%s" \nwill be used during training (effective maximum steps = %d) - \nParameters : \n(%s)', + str(schedule), + max_steps, + OmegaConf.to_yaml(OmegaConf.create(scheduler_args)), + ) + + # Wrap the schedule in PTL arguments to perform stepwise computation + # Rather than epoch level computation + if isinstance(schedule, optim.lr_scheduler.ReduceLROnPlateau): + reduce_lr_on_plateau = True + else: + reduce_lr_on_plateau = False + + schedule_dict = { + 'scheduler': schedule, + 'interval': interval, + 'frequency': 1, + 'monitor': monitor, + 'reduce_on_plateau': reduce_lr_on_plateau, + } + return schedule_dict + + +def compute_max_steps( + max_epochs, accumulate_grad_batches, limit_train_batches, num_workers, num_samples, batch_size, drop_last +): + _round = math.floor if drop_last else math.ceil + + sampler_num_samples = math.ceil(num_samples / num_workers) + + if drop_last and num_workers > 1: + logging.warning( + "Please note that drop_last is broken in pytorch 1.6.0. We will fix when pytorch 1.7.0 is released" + ) + # TODO: Master verion, not in pytorch 1.6.0 + # sampler_num_samples = math.ceil((num_samples - num_workers)/ num_workers) + + steps_per_epoch = _round(sampler_num_samples / batch_size) + if isinstance(limit_train_batches, int) or limit_train_batches == 0.0: + steps_per_epoch = min(steps_per_epoch, int(limit_train_batches)) + elif steps_per_epoch != float('inf'): + # limit_train_batches is a percentage of batches per epoch + steps_per_epoch = int(steps_per_epoch * limit_train_batches) + if accumulate_grad_batches == 1: + steps_per_epoch = max(steps_per_epoch, 1) + + return math.ceil(steps_per_epoch / accumulate_grad_batches) * max_epochs + + +AVAILABLE_SCHEDULERS = { + 'WarmupPolicy': WarmupPolicy, + 'WarmupHoldPolicy': WarmupHoldPolicy, + 'SquareAnnealing': SquareAnnealing, + 'CosineAnnealing': CosineAnnealing, + 'NoamAnnealing': NoamAnnealing, + 'WarmupAnnealing': WarmupAnnealing, + 'InverseSquareRootAnnealing': InverseSquareRootAnnealing, + 'SquareRootAnnealing': SquareRootAnnealing, + 'PolynomialDecayAnnealing': PolynomialDecayAnnealing, + 'PolynomialHoldDecayAnnealing': PolynomialHoldDecayAnnealing, + 'StepLR': pt_scheduler.StepLR, + 'ExponentialLR': pt_scheduler.ExponentialLR, + 'ReduceLROnPlateau': pt_scheduler.ReduceLROnPlateau, + 'CyclicLR': pt_scheduler.CyclicLR, +} diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/optim/novograd.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/optim/novograd.py new file mode 100644 index 0000000000000000000000000000000000000000..34415a505e48f947baab5740a86686a63cbd1215 --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/optim/novograd.py @@ -0,0 +1,198 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from torch.optim.optimizer import Optimizer + +__all__ = ['Novograd'] + + +def _check_valid_opt_params(lr, eps, betas): + if lr < 0: + raise ValueError(f"Invalid learning rate: {lr}") + if eps < 0: + raise ValueError(f"Invalid epsilon value: {eps}") + if not (0.0 <= betas[0] < 1.0 and 0.0 <= betas[1] < 1.0): + raise ValueError(f"Betas have to be between 0 and 1: {betas}") + + +class Novograd(Optimizer): + """Implements Novograd algorithm. + It has been proposed in "Stochastic Gradient Methods with Layer-wise + Adaptive Moments for Training of Deep Networks" + (https://arxiv.org/abs/1905.11286) + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + amsgrad (boolean, optional): whether to use the AMSGrad variant of this + algorithm from the paper "On the Convergence of Adam and Beyond" + """ + + def __init__( + self, + params, + lr=1e-3, + betas=(0.95, 0.98), + eps=1e-8, + eps_in_sqrt=False, + weight_decay=0, + weight_decay_ema=True, + grad_averaging=False, + amsgrad=False, + luc=False, + luc_grad_trust=0.0, + luc_grad_trust_rel=False, + luc_trust=1e-3, + luc_trust_min=0.0, + luc_eps=1e-8, + luc_update_min=1e-7, + luc_update_max=1.0, + ): + _check_valid_opt_params(lr, eps, betas) + assert isinstance(eps_in_sqrt, bool) + defaults = dict( + lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, grad_averaging=grad_averaging, amsgrad=amsgrad, + eps_in_sqrt=eps_in_sqrt, + luc=luc, + luc_grad_trust=luc_grad_trust, + luc_grad_trust_rel=luc_grad_trust_rel, + luc_trust=luc_trust, + luc_trust_min=luc_trust_min, + luc_eps=luc_eps, + luc_update_min=luc_update_min, + luc_update_max=luc_update_max, + weight_decay_ema=weight_decay_ema + ) + super(Novograd, self).__init__(params, defaults) + + def __setstate__(self, state): + super(Novograd, self).__setstate__(state) + for group in self.param_groups: + group.setdefault("amsgrad", False) + + def step(self, closure=None): + """Performs a single optimization step. + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: + continue + grad = p.grad.data + if grad.is_sparse: + raise RuntimeError("Sparse gradients are not supported.") + + amsgrad = group["amsgrad"] + state = self.state[p] + + # State initialization + if not state: + state["step"] = 0 + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like(p.data) + # Exponential moving average of squared gradient values + state["exp_avg_sq"] = torch.zeros([]).to(state["exp_avg"].device) + if amsgrad: + # Maintains max of all exp moving avg of squared grad + state["max_exp_avg_sq"] = torch.zeros([]).to(state["exp_avg"].device) + + exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] + if amsgrad: + max_exp_avg_sq = state["max_exp_avg_sq"] + beta1, beta2 = group["betas"] + + state["step"] += 1 + + if group['luc'] and group['luc_grad_trust'] > 0: + if not group['luc_grad_trust_rel']: + # Clip grad so that grad are less than eta*weights + luc_factor = get_luc_factor(p.data, grad, luc_trust=group['luc_grad_trust'], luc_trust_min=0.0) + grad.mul_(luc_factor) + else: + if exp_avg_sq != 0: + luc_factor = get_luc_factor(exp_avg_sq.sqrt(), grad, luc_trust=group['luc_grad_trust'], luc_trust_min=0.0) + grad.mul_(luc_factor) + + norm = grad.norm().pow(2) + + if exp_avg_sq == 0: + exp_avg_sq.copy_(norm) + else: + exp_avg_sq.mul_(beta2).add_(norm, alpha=1.0 - beta2) + + if amsgrad: + # Maintains max of all 2nd moment running avg till now + torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) + # Use the max for normalizing running avg. of gradient + if not group['eps_in_sqrt']: + denom = max_exp_avg_sq.sqrt().add_(group["eps"]) + else: + denom = max_exp_avg_sq.add_(group["eps"]).sqrt() + else: + if not group['eps_in_sqrt']: + denom = exp_avg_sq.sqrt().add_(group["eps"]) + else: + denom = exp_avg_sq.add_(group["eps"]).sqrt() + + grad.div_(denom) + if group["weight_decay"] != 0 and group['weight_decay_ema']: + grad.add_(p.data, alpha=group["weight_decay"]) + if group["grad_averaging"]: + grad.mul_(1 - beta1) + exp_avg.mul_(beta1).add_(grad) + + update = exp_avg + if group["weight_decay"] != 0 and not group['weight_decay_ema']: + update = update.add(p.data, alpha=group["weight_decay"]) + + lr = group["lr"] + if group['luc'] and group['luc_trust'] > 0: + # Clip lr so that updates are less than eta*weights + luc_factor = get_luc_factor(p.data, update.data, luc_trust=group['luc_trust'], luc_trust_min=group['luc_trust_min']) + lr = luc_factor * lr + + p.data.add_(update, alpha=-lr) + + return loss + + +def get_luc_factor(param, grad, *, luc_trust, luc_trust_min): + param_norm = torch.norm(param) + param_norm = max(param_norm, 1e-3) + grad_norm = torch.norm(grad) + if grad_norm == 0: + return 1.0 + + max_grad_norm = param_norm * luc_trust + min_grad_norm = param_norm * luc_trust_min + if grad_norm > max_grad_norm: + luc_factor = max_grad_norm / grad_norm + elif grad_norm < min_grad_norm: + luc_factor = min_grad_norm / grad_norm + else: + luc_factor = 1.0 + + return luc_factor diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/optim/optimizers.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/optim/optimizers.py new file mode 100644 index 0000000000000000000000000000000000000000..bc326589d15fb2227d0a6a1650747028c7a8a833 --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/optim/optimizers.py @@ -0,0 +1,164 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +from functools import partial +from typing import Any, Dict, List, Optional, Union + +import hydra +import torch.optim as optim +from omegaconf import DictConfig, OmegaConf +from torch.optim import adadelta, adagrad, adamax, rmsprop, rprop +from torch.optim.optimizer import Optimizer + +from nemo.core.config import OptimizerParams, get_optimizer_config, register_optimizer_params +from nemo.core.optim.novograd import Novograd + +__all__ = ['get_optimizer', 'register_optimizer', 'parse_optimizer_args'] + + +AVAILABLE_OPTIMIZERS = { + 'sgd': optim.SGD, + 'adam': optim.Adam, + 'adamw': optim.AdamW, + # 'adadelta': adadelta.Adadelta, + # 'adamax': adamax.Adamax, + # 'adagrad': adagrad.Adagrad, + # 'rmsprop': rmsprop.RMSprop, + # 'rprop': rprop.Rprop, + 'novograd': Novograd, +} + + +def parse_optimizer_args( + optimizer_name: str, optimizer_kwargs: Union[DictConfig, Dict[str, Any]] +) -> Union[Dict[str, Any], DictConfig]: + """ + Parses a list of strings, of the format "key=value" or "key2=val1,val2,..." + into a dictionary of type {key=value, key2=[val1, val2], ...} + + This dictionary is then used to instantiate the chosen Optimizer. + + Args: + optimizer_name: string name of the optimizer, used for auto resolution of params + optimizer_kwargs: Either a list of strings in a specified format, + or a dictionary. If a dictionary is provided, it is assumed the dictionary + is the final parsed value, and simply returned. + If a list of strings is provided, each item in the list is parsed into a + new dictionary. + + Returns: + A dictionary + """ + kwargs = {} + + if optimizer_kwargs is None: + return kwargs + + optimizer_kwargs = copy.deepcopy(optimizer_kwargs) + + if isinstance(optimizer_kwargs, DictConfig): + optimizer_kwargs = OmegaConf.to_container(optimizer_kwargs, resolve=True) + + # If it is a dictionary, perform stepwise resolution + if hasattr(optimizer_kwargs, 'keys'): + # Attempt class path resolution + try: + optimizer_kwargs_config = OmegaConf.create(optimizer_kwargs) + optimizer_instance = hydra.utils.instantiate(optimizer_kwargs_config) # type: DictConfig + optimizer_instance = vars(optimizer_instance) + return optimizer_instance + except Exception: + pass + + # If class path was not provided, perhaps `name` is provided for resolution + if 'name' in optimizer_kwargs: + # If `auto` is passed as name for resolution of optimizer name, + # then lookup optimizer name and resolve its parameter config + if optimizer_kwargs['name'] == 'auto': + optimizer_params_name = "{}_params".format(optimizer_name) + optimizer_kwargs.pop('name') + else: + optimizer_params_name = optimizer_kwargs.pop('name') + + # Override arguments provided in the config yaml file + if 'params' in optimizer_kwargs: + # If optimizer kwarg overrides are wrapped in yaml `params` + optimizer_params_override = optimizer_kwargs.get('params') + else: + # If the kwargs themselves are a DictConfig + optimizer_params_override = optimizer_kwargs + + if isinstance(optimizer_params_override, DictConfig): + optimizer_params_override = OmegaConf.to_container(optimizer_params_override, resolve=True) + + optimizer_params_cls = get_optimizer_config(optimizer_params_name, **optimizer_params_override) + + # If we are provided just a Config object, simply return the dictionary of that object + if optimizer_params_name is None: + optimizer_params = vars(optimizer_params_cls) + return optimizer_params + + else: + # If we are provided a partial class instantiation of a Config, + # Instantiate it and retrieve its vars as a dictionary + optimizer_params = optimizer_params_cls() # instantiate the parameters object + optimizer_params = vars(optimizer_params) + return optimizer_params + + # simply return the dictionary that was provided + return optimizer_kwargs + + return kwargs + + +def register_optimizer(name: str, optimizer: Optimizer, optimizer_params: OptimizerParams): + """ + Checks if the optimizer name exists in the registry, and if it doesnt, adds it. + + This allows custom optimizers to be added and called by name during instantiation. + + Args: + name: Name of the optimizer. Will be used as key to retrieve the optimizer. + optimizer: Optimizer class + optimizer_params: The parameters as a dataclass of the optimizer + """ + if name in AVAILABLE_OPTIMIZERS: + raise ValueError(f"Cannot override pre-existing optimizers. Conflicting optimizer name = {name}") + + AVAILABLE_OPTIMIZERS[name] = optimizer + + optim_name = "{}_params".format(optimizer.__name__) + register_optimizer_params(name=optim_name, optimizer_params=optimizer_params) + + +def get_optimizer(name: str, **kwargs: Optional[Dict[str, Any]]) -> Optimizer: + """ + Convenience method to obtain an Optimizer class and partially instantiate it with optimizer kwargs. + + Args: + name: Name of the Optimizer in the registry. + kwargs: Optional kwargs of the optimizer used during instantiation. + + Returns: + a partially instantiated Optimizer + """ + if name not in AVAILABLE_OPTIMIZERS: + raise ValueError( + f"Cannot resolve optimizer '{name}'. Available optimizers are : " f"{AVAILABLE_OPTIMIZERS.keys()}" + ) + + optimizer = AVAILABLE_OPTIMIZERS[name] + optimizer = partial(optimizer, **kwargs) + return optimizer diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/package_info.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/package_info.py new file mode 100644 index 0000000000000000000000000000000000000000..3843cf41a6090d636866e565aab1907af94dcf22 --- /dev/null +++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/package_info.py @@ -0,0 +1,35 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +MAJOR = 1 +MINOR = 0 +PATCH = 0 +PRE_RELEASE = 'b4' + +# Use the following formatting: (major, minor, patch, pre-release) +VERSION = (MAJOR, MINOR, PATCH, PRE_RELEASE) + +__shortversion__ = '.'.join(map(str, VERSION[:3])) +__version__ = '.'.join(map(str, VERSION[:3])) + ''.join(VERSION[3:]) + +__package_name__ = 'nemo_toolkit' +__contact_names__ = 'NVIDIA' +__contact_emails__ = 'nemo-toolkit@nvidia.com' +__homepage__ = 'https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/main/' +__repository_url__ = 'https://github.com/nvidia/nemo' +__download_url__ = 'https://github.com/NVIDIA/NeMo/releases' +__description__ = 'NeMo - a toolkit for Conversational AI' +__license__ = 'Apache2' +__keywords__ = 'deep learning, machine learning, gpu, NLP, NeMo, nvidia, pytorch, torch, tts, speech, language' diff --git a/MMaDA/models/speech_tokenization/UVITS/attentions.py b/MMaDA/models/speech_tokenization/UVITS/attentions.py new file mode 100644 index 0000000000000000000000000000000000000000..383c5da5c34103003a973aab97cc39180db35fda --- /dev/null +++ b/MMaDA/models/speech_tokenization/UVITS/attentions.py @@ -0,0 +1,313 @@ +import copy +import math +import numpy as np +import torch +from torch import nn +from torch.nn import functional as F + +import commons +import modules +from modules import LayerNorm + + +class Encoder(nn.Module): + def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0., window_size=4, + **kwargs): + super().__init__() + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.window_size = window_size + + self.drop = nn.Dropout(p_dropout) + self.attn_layers = nn.ModuleList() + self.norm_layers_1 = nn.ModuleList() + self.ffn_layers = nn.ModuleList() + self.norm_layers_2 = nn.ModuleList() + for i in range(self.n_layers): + self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout, + window_size=window_size)) + self.norm_layers_1.append(LayerNorm(hidden_channels)) + self.ffn_layers.append( + FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout)) + self.norm_layers_2.append(LayerNorm(hidden_channels)) + + def forward(self, x, x_mask): + attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1) + x = x * x_mask + for i in range(self.n_layers): + y = self.attn_layers[i](x, x, attn_mask) + y = self.drop(y) + x = self.norm_layers_1[i](x + y) + + y = self.ffn_layers[i](x, x_mask) + y = self.drop(y) + x = self.norm_layers_2[i](x + y) + x = x * x_mask + return x + + +class Decoder(nn.Module): + def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0., + proximal_bias=False, proximal_init=True, **kwargs): + super().__init__() + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.proximal_bias = proximal_bias + self.proximal_init = proximal_init + + self.drop = nn.Dropout(p_dropout) + self.self_attn_layers = nn.ModuleList() + self.norm_layers_0 = nn.ModuleList() + self.encdec_attn_layers = nn.ModuleList() + self.norm_layers_1 = nn.ModuleList() + self.ffn_layers = nn.ModuleList() + self.norm_layers_2 = nn.ModuleList() + for i in range(self.n_layers): + self.self_attn_layers.append( + MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout, + proximal_bias=proximal_bias, proximal_init=proximal_init)) + self.norm_layers_0.append(LayerNorm(hidden_channels)) + self.encdec_attn_layers.append( + MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout)) + self.norm_layers_1.append(LayerNorm(hidden_channels)) + self.ffn_layers.append( + FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout, causal=True)) + self.norm_layers_2.append(LayerNorm(hidden_channels)) + + def forward(self, x, x_mask, h, h_mask): + """ + x: decoder input + h: encoder output + """ + self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(device=x.device, dtype=x.dtype) + encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1) + x = x * x_mask + for i in range(self.n_layers): + y = self.self_attn_layers[i](x, x, self_attn_mask) + y = self.drop(y) + x = self.norm_layers_0[i](x + y) + + y = self.encdec_attn_layers[i](x, h, encdec_attn_mask) + y = self.drop(y) + x = self.norm_layers_1[i](x + y) + + y = self.ffn_layers[i](x, x_mask) + y = self.drop(y) + x = self.norm_layers_2[i](x + y) + x = x * x_mask + return x + + +class MultiHeadAttention(nn.Module): + def __init__(self, channels, out_channels, n_heads, p_dropout=0., window_size=None, heads_share=True, + block_length=None, proximal_bias=False, proximal_init=False): + super().__init__() + assert channels % n_heads == 0 + + self.channels = channels + self.out_channels = out_channels + self.n_heads = n_heads + self.p_dropout = p_dropout + self.window_size = window_size + self.heads_share = heads_share + self.block_length = block_length + self.proximal_bias = proximal_bias + self.proximal_init = proximal_init + self.attn = None + + self.k_channels = channels // n_heads + self.conv_q = nn.Conv1d(channels, channels, 1) + self.conv_k = nn.Conv1d(channels, channels, 1) + self.conv_v = nn.Conv1d(channels, channels, 1) + self.conv_o = nn.Conv1d(channels, out_channels, 1) + self.drop = nn.Dropout(p_dropout) + + if window_size is not None: + n_heads_rel = 1 if heads_share else n_heads + rel_stddev = self.k_channels ** -0.5 + self.emb_rel_k = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev) + self.emb_rel_v = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev) + + nn.init.xavier_uniform_(self.conv_q.weight) + nn.init.xavier_uniform_(self.conv_k.weight) + nn.init.xavier_uniform_(self.conv_v.weight) + if proximal_init: + with torch.no_grad(): + self.conv_k.weight.copy_(self.conv_q.weight) + self.conv_k.bias.copy_(self.conv_q.bias) + + def forward(self, x, c, attn_mask=None): + q = self.conv_q(x) + k = self.conv_k(c) + v = self.conv_v(c) + + x, self.attn = self.attention(q, k, v, mask=attn_mask) + + x = self.conv_o(x) + return x + + def attention(self, query, key, value, mask=None): + # reshape [b, d, t] -> [b, n_h, t, d_k] + b, d, t_s, t_t = (*key.size(), query.size(2)) + query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3) + key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) + value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) + + scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1)) + if self.window_size is not None: + assert t_s == t_t, "Relative attention is only available for self-attention." + key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s) + rel_logits = self._matmul_with_relative_keys(query / math.sqrt(self.k_channels), key_relative_embeddings) + scores_local = self._relative_position_to_absolute_position(rel_logits) + scores = scores + scores_local + if self.proximal_bias: + assert t_s == t_t, "Proximal bias is only available for self-attention." + scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype) + if mask is not None: + scores = scores.masked_fill(mask == 0, -1e4) + if self.block_length is not None: + assert t_s == t_t, "Local attention is only available for self-attention." + block_mask = torch.ones_like(scores).triu(-self.block_length).tril(self.block_length) + scores = scores.masked_fill(block_mask == 0, -1e4) + p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s] + p_attn = self.drop(p_attn) + output = torch.matmul(p_attn, value) + if self.window_size is not None: + relative_weights = self._absolute_position_to_relative_position(p_attn) + value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s) + output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings) + output = output.transpose(2, 3).contiguous().view(b, d, t_t) # [b, n_h, t_t, d_k] -> [b, d, t_t] + return output, p_attn + + def _matmul_with_relative_values(self, x, y): + """ + x: [b, h, l, m] + y: [h or 1, m, d] + ret: [b, h, l, d] + """ + ret = torch.matmul(x, y.unsqueeze(0)) + return ret + + def _matmul_with_relative_keys(self, x, y): + """ + x: [b, h, l, d] + y: [h or 1, m, d] + ret: [b, h, l, m] + """ + ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1)) + return ret + + def _get_relative_embeddings(self, relative_embeddings, length): + max_relative_position = 2 * self.window_size + 1 + # Pad first before slice to avoid using cond ops. + pad_length = max(length - (self.window_size + 1), 0) + slice_start_position = max((self.window_size + 1) - length, 0) + slice_end_position = slice_start_position + 2 * length - 1 + if pad_length > 0: + padded_relative_embeddings = F.pad( + relative_embeddings, + commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]])) + else: + padded_relative_embeddings = relative_embeddings + used_relative_embeddings = padded_relative_embeddings[:, slice_start_position:slice_end_position] + return used_relative_embeddings + + def _relative_position_to_absolute_position(self, x): + """ + x: [b, h, l, 2*l-1] + ret: [b, h, l, l] + """ + batch, heads, length, _ = x.size() + # Concat columns of pad to shift from relative to absolute indexing. + x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]])) + + # Concat extra elements so to add up to shape (len+1, 2*len-1). + x_flat = x.view([batch, heads, length * 2 * length]) + x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])) + + # Reshape and slice out the padded elements. + x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[:, :, :length, length - 1:] + return x_final + + def _absolute_position_to_relative_position(self, x): + """ + x: [b, h, l, l] + ret: [b, h, l, 2*l-1] + """ + batch, heads, length, _ = x.size() + # padd along column + x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])) + x_flat = x.view([batch, heads, length ** 2 + length * (length - 1)]) + # add 0's in the beginning that will skew the elements after reshape + x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]])) + x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:] + return x_final + + def _attention_bias_proximal(self, length): + """Bias for self-attention to encourage attention to close positions. + Args: + length: an integer scalar. + Returns: + a Tensor with shape [1, 1, length, length] + """ + r = torch.arange(length, dtype=torch.float32) + diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1) + return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0) + + +class FFN(nn.Module): + def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0., activation=None, + causal=False): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.filter_channels = filter_channels + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.activation = activation + self.causal = causal + + if causal: + self.padding = self._causal_padding + else: + self.padding = self._same_padding + + self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size) + self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size) + self.drop = nn.Dropout(p_dropout) + + def forward(self, x, x_mask): + x = self.conv_1(self.padding(x * x_mask)) + if self.activation == "gelu": + x = x * torch.sigmoid(1.702 * x) + else: + x = torch.relu(x) + x = self.drop(x) + x = self.conv_2(self.padding(x * x_mask)) + return x * x_mask + + def _causal_padding(self, x): + if self.kernel_size == 1: + return x + pad_l = self.kernel_size - 1 + pad_r = 0 + padding = [[0, 0], [0, 0], [pad_l, pad_r]] + x = F.pad(x, commons.convert_pad_shape(padding)) + return x + + def _same_padding(self, x): + if self.kernel_size == 1: + return x + pad_l = (self.kernel_size - 1) // 2 + pad_r = self.kernel_size // 2 + padding = [[0, 0], [0, 0], [pad_l, pad_r]] + x = F.pad(x, commons.convert_pad_shape(padding)) + return x diff --git a/MMaDA/models/speech_tokenization/UVITS/commons.py b/MMaDA/models/speech_tokenization/UVITS/commons.py new file mode 100644 index 0000000000000000000000000000000000000000..970489852841b0350f945b10e1c6e572860e9da8 --- /dev/null +++ b/MMaDA/models/speech_tokenization/UVITS/commons.py @@ -0,0 +1,161 @@ +import math +import numpy as np +import torch +from torch import nn +from torch.nn import functional as F + + +def init_weights(m, mean=0.0, std=0.01): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(mean, std) + + +def get_padding(kernel_size, dilation=1): + return int((kernel_size * dilation - dilation) / 2) + + +def convert_pad_shape(pad_shape): + l = pad_shape[::-1] + pad_shape = [item for sublist in l for item in sublist] + return pad_shape + + +def intersperse(lst, item): + result = [item] * (len(lst) * 2 + 1) + result[1::2] = lst + return result + + +def kl_divergence(m_p, logs_p, m_q, logs_q): + """KL(P||Q)""" + kl = (logs_q - logs_p) - 0.5 + kl += 0.5 * (torch.exp(2. * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2. * logs_q) + return kl + + +def rand_gumbel(shape): + """Sample from the Gumbel distribution, protect from overflows.""" + uniform_samples = torch.rand(shape) * 0.99998 + 0.00001 + return -torch.log(-torch.log(uniform_samples)) + + +def rand_gumbel_like(x): + g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device) + return g + + +def slice_segments(x, ids_str, segment_size=4): + ret = torch.zeros_like(x[:, :, :segment_size]) + for i in range(x.size(0)): + idx_str = ids_str[i] + idx_end = idx_str + segment_size + ret[i] = x[i, :, idx_str:idx_end] + return ret + + +def rand_slice_segments(x, x_lengths=None, segment_size=4): + b, d, t = x.size() + if x_lengths is None: + x_lengths = t + ids_str_max = x_lengths - segment_size + 1 + ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long) + ret = slice_segments(x, ids_str, segment_size) + return ret, ids_str + + +def get_timing_signal_1d( + length, channels, min_timescale=1.0, max_timescale=1.0e4): + position = torch.arange(length, dtype=torch.float) + num_timescales = channels // 2 + log_timescale_increment = ( + math.log(float(max_timescale) / float(min_timescale)) / + (num_timescales - 1)) + inv_timescales = min_timescale * torch.exp( + torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment) + scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1) + signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0) + signal = F.pad(signal, [0, 0, 0, channels % 2]) + signal = signal.view(1, channels, length) + return signal + + +def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4): + b, channels, length = x.size() + signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) + return x + signal.to(dtype=x.dtype, device=x.device) + + +def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1): + b, channels, length = x.size() + signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) + return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis) + + +def subsequent_mask(length): + mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0) + return mask + + +@torch.jit.script +def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): + n_channels_int = n_channels[0] + in_act = input_a + input_b + t_act = torch.tanh(in_act[:, :n_channels_int, :]) + s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) + acts = t_act * s_act + return acts + + +def convert_pad_shape(pad_shape): + l = pad_shape[::-1] + pad_shape = [item for sublist in l for item in sublist] + return pad_shape + + +def shift_1d(x): + x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1] + return x + + +def sequence_mask(length, max_length=None): + if max_length is None: + max_length = length.max() + x = torch.arange(max_length, dtype=length.dtype, device=length.device) + return x.unsqueeze(0) < length.unsqueeze(1) + + +def generate_path(duration, mask): + """ + duration: [b, 1, t_x] + mask: [b, 1, t_y, t_x] + """ + device = duration.device + + b, _, t_y, t_x = mask.shape + cum_duration = torch.cumsum(duration, -1) + + cum_duration_flat = cum_duration.view(b * t_x) + path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) + path = path.view(b, t_x, t_y) + path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1] + path = path.unsqueeze(1).transpose(2, 3) * mask + return path + + +def clip_grad_value_(parameters, clip_value, norm_type=2): + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + parameters = list(filter(lambda p: p.grad is not None, parameters)) + norm_type = float(norm_type) + if clip_value is not None: + clip_value = float(clip_value) + + total_norm = 0 + for p in parameters: + param_norm = p.grad.data.norm(norm_type) + total_norm += param_norm.item() ** norm_type + if clip_value is not None: + p.grad.data.clamp_(min=-clip_value, max=clip_value) + total_norm = total_norm ** (1. / norm_type) + return total_norm diff --git a/MMaDA/models/speech_tokenization/UVITS/data_utils.py b/MMaDA/models/speech_tokenization/UVITS/data_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8621250a8d201b0c3e19b24b6955bfdd2cd1503d --- /dev/null +++ b/MMaDA/models/speech_tokenization/UVITS/data_utils.py @@ -0,0 +1,398 @@ +import time +import os +import random +import numpy as np +import torch +import torch.utils.data + +import commons +from mel_processing import spectrogram_torch +from utils import load_wav_to_torch, load_filepaths_and_text +from text import text_to_sequence, cleaned_text_to_sequence +import librosa + + +class TextAudioLoader(torch.utils.data.Dataset): + """ + 1) loads audio, text pairs + 2) normalizes text and converts them to sequences of integers + 3) computes spectrograms from audio files. + """ + + def __init__(self, audiopaths_and_text, hparams, is_training=True): + self.audiopaths_and_text = load_filepaths_and_text(audiopaths_and_text) + self.text_cleaners = hparams.text_cleaners + self.max_wav_value = hparams.max_wav_value + self.sampling_rate = hparams.sampling_rate + self.filter_length = hparams.filter_length + self.hop_length = hparams.hop_length + self.win_length = hparams.win_length + self.sampling_rate = hparams.sampling_rate + + self.cleaned_text = getattr(hparams, "cleaned_text", False) + + self.add_blank = hparams.add_blank + self.min_text_len = getattr(hparams, "min_text_len", 1) + self.max_text_len = getattr(hparams, "max_text_len", 190) + + if is_training: + random.seed(1234) + random.shuffle(self.audiopaths_and_text) + self._filter() + + def _filter(self): + """ + Filter text & store spec lengths + """ + # Store spectrogram lengths for Bucketing + # wav_length ~= file_size / (wav_channels * Bytes per dim) = file_size / (1 * 2) + # spec_length = wav_length // hop_length + + audiopaths_and_text_new = [] + lengths = [] + for audiopath, text in self.audiopaths_and_text: + if self.min_text_len <= len(text) and len(text) <= self.max_text_len: + if os.path.exists(audiopath): + audiopaths_and_text_new.append([audiopath, text]) + lengths.append(os.path.getsize(audiopath) // (2 * self.hop_length)) + self.audiopaths_and_text = audiopaths_and_text_new + self.lengths = lengths + + def get_audio_text_pair(self, audiopath_and_text): + # separate filename and text + audiopath, text = audiopath_and_text[0], audiopath_and_text[1] + text = self.get_text(text) + spec, wav = self.get_audio(audiopath) + return (text, spec, wav) + + def get_audio(self, filename): + audio, sampling_rate = librosa.load(filename, sr=self.sampling_rate) + audio_norm = torch.FloatTensor(audio.astype(np.float32)).unsqueeze(0) + + spec = spectrogram_torch(audio_norm, self.filter_length, + self.sampling_rate, self.hop_length, self.win_length, + center=False) + + return spec, audio_norm + + def get_text(self, text): + if self.cleaned_text: + text_norm = cleaned_text_to_sequence(text) + else: + text_norm = text_to_sequence(text, self.text_cleaners) + if self.add_blank: + text_norm = commons.intersperse(text_norm, 0) + text_norm = torch.LongTensor(text_norm) + return text_norm + + def __getitem__(self, index): + return self.get_audio_text_pair(self.audiopaths_and_text[index]) + + def __len__(self): + return len(self.audiopaths_and_text) + + +class TextAudioCollate(): + """ Zero-pads model inputs and targets + """ + + def __init__(self, return_ids=False): + self.return_ids = return_ids + + def __call__(self, batch): + """Collate's training batch from normalized text and aduio + PARAMS + ------ + batch: [text_normalized, spec_normalized, wav_normalized] + """ + # Right zero-pad all one-hot text sequences to max input length + _, ids_sorted_decreasing = torch.sort( + torch.LongTensor([x[1].size(1) for x in batch]), + dim=0, descending=True) + + max_text_len = max([len(x[0]) for x in batch]) + max_spec_len = max([x[1].size(1) for x in batch]) + max_wav_len = max([x[2].size(1) for x in batch]) + + text_lengths = torch.LongTensor(len(batch)) + spec_lengths = torch.LongTensor(len(batch)) + wav_lengths = torch.LongTensor(len(batch)) + + text_padded = torch.LongTensor(len(batch), max_text_len) + spec_padded = torch.FloatTensor(len(batch), batch[0][1].size(0), max_spec_len) + wav_padded = torch.FloatTensor(len(batch), 1, max_wav_len) + text_padded.zero_() + spec_padded.zero_() + wav_padded.zero_() + for i in range(len(ids_sorted_decreasing)): + row = batch[ids_sorted_decreasing[i]] + + text = row[0] + text_padded[i, :text.size(0)] = text + text_lengths[i] = text.size(0) + + spec = row[1] + spec_padded[i, :, :spec.size(1)] = spec + spec_lengths[i] = spec.size(1) + + wav = row[2] + wav_padded[i, :, :wav.size(1)] = wav + wav_lengths[i] = wav.size(1) + + if self.return_ids: + return text_padded, text_lengths, spec_padded, spec_lengths, wav_padded, wav_lengths, ids_sorted_decreasing + return text_padded, text_lengths, spec_padded, spec_lengths, wav_padded, wav_lengths + + +"""Multi speaker version""" + + +class TextAudioSpeakerLoader(torch.utils.data.Dataset): + """ + 1) loads audio, speaker_id, text pairs + 2) normalizes text and converts them to sequences of integers + 3) computes spectrograms from audio files. + """ + + def __init__(self, audiopaths_sid_text, hparams, is_training=True): + self.audiopaths_sid_text = load_filepaths_and_text(audiopaths_sid_text, hparams, is_training=is_training) + self.text_cleaners = hparams.text_cleaners + self.max_wav_value = hparams.max_wav_value + self.sampling_rate = hparams.sampling_rate + self.filter_length = hparams.filter_length + self.hop_length = hparams.hop_length + self.win_length = hparams.win_length + self.sampling_rate = hparams.sampling_rate + + self.cleaned_text = getattr(hparams, "cleaned_text", False) + + self.add_blank = hparams.add_blank + self.min_text_len = getattr(hparams, "min_text_len", 1) + self.max_text_len = getattr(hparams, "max_text_len", 190) + + if is_training: + random.seed(1234) + random.shuffle(self.audiopaths_sid_text) + self._filter() + + def _filter(self): + """ + Filter text & store spec lengths + """ + # Store spectrogram lengths for Bucketing + # wav_length ~= file_size / (wav_channels * Bytes per dim) = file_size / (1 * 2) + # spec_length = wav_length // hop_length + + audiopaths_sid_text_new = [] + lengths = [] + filters = 0 + for audiopath, sid, text in self.audiopaths_sid_text: + if not self.cleaned_text and 'text_split' in self.text_cleaners: + text_ = text.split() + if self.min_text_len <= len(text_) and len(text_) <= self.max_text_len and os.path.exists(audiopath): + audiopaths_sid_text_new.append([audiopath, sid, text]) + lengths.append(os.path.getsize(audiopath) // (2 * self.hop_length)) + else: + filters += 1 + else: + if self.min_text_len <= len(text) and len(text) <= self.max_text_len: + audiopaths_sid_text_new.append([audiopath, sid, text]) + lengths.append(os.path.getsize(audiopath) // (2 * self.hop_length)) + else: + filters += 1 + self.audiopaths_sid_text = audiopaths_sid_text_new + self.lengths = lengths + print(f"Filter out {filters} files") + + def get_audio_text_speaker_pair(self, audiopath_sid_text): + # separate filename, speaker_id and text + audiopath, sid, text = audiopath_sid_text[0], audiopath_sid_text[1], audiopath_sid_text[2] + text = self.get_text(text) + spec, wav = self.get_audio(audiopath) + sid = self.get_sid(sid) + return (text, spec, wav, sid) + + def get_audio(self, filename): + audio, sampling_rate = librosa.load(filename, sr=self.sampling_rate) + audio_norm = torch.FloatTensor(audio.astype(np.float32)).unsqueeze(0) + + spec = spectrogram_torch(audio_norm, self.filter_length, + self.sampling_rate, self.hop_length, self.win_length, + center=False) + spec = torch.squeeze(spec, 0) + return spec, audio_norm + + def get_text(self, text): + if self.cleaned_text: + text_norm = cleaned_text_to_sequence(text) + else: + text_norm = text_to_sequence(text, self.text_cleaners) + if self.add_blank: + text_norm = commons.intersperse(text_norm, 0) + text_norm = torch.LongTensor(text_norm) + return text_norm + + def get_sid(self, sid): + sid = torch.LongTensor([int(sid)]) + return sid + + def __getitem__(self, index): + return self.get_audio_text_speaker_pair(self.audiopaths_sid_text[index]) + + def __len__(self): + return len(self.audiopaths_sid_text) + + +class TextAudioSpeakerCollate(): + """ Zero-pads model inputs and targets + """ + + def __init__(self, return_ids=False): + self.return_ids = return_ids + + def __call__(self, batch): + """Collate's training batch from normalized text, audio and speaker identities + PARAMS + ------ + batch: [text_normalized, spec_normalized, wav_normalized, sid] + """ + # Right zero-pad all one-hot text sequences to max input length + _, ids_sorted_decreasing = torch.sort( + torch.LongTensor([x[1].size(1) for x in batch]), + dim=0, descending=True) + + max_text_len = max([len(x[0]) for x in batch]) + max_spec_len = max([x[1].size(1) for x in batch]) + max_wav_len = max([x[2].size(1) for x in batch]) + + text_lengths = torch.LongTensor(len(batch)) + spec_lengths = torch.LongTensor(len(batch)) + wav_lengths = torch.LongTensor(len(batch)) + sid = torch.LongTensor(len(batch)) + + text_padded = torch.LongTensor(len(batch), max_text_len) + spec_padded = torch.FloatTensor(len(batch), batch[0][1].size(0), max_spec_len) + wav_padded = torch.FloatTensor(len(batch), 1, max_wav_len) + text_padded.zero_() + spec_padded.zero_() + wav_padded.zero_() + for i in range(len(ids_sorted_decreasing)): + row = batch[ids_sorted_decreasing[i]] + + text = row[0] + text_padded[i, :text.size(0)] = text + text_lengths[i] = text.size(0) + + spec = row[1] + spec_padded[i, :, :spec.size(1)] = spec + spec_lengths[i] = spec.size(1) + + wav = row[2] + wav_padded[i, :, :wav.size(1)] = wav + wav_lengths[i] = wav.size(1) + + sid[i] = row[3] + + if self.return_ids: + return text_padded, text_lengths, spec_padded, spec_lengths, wav_padded, wav_lengths, sid, ids_sorted_decreasing + return text_padded, text_lengths, spec_padded, spec_lengths, wav_padded, wav_lengths, sid + + +class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler): + """ + Maintain similar input lengths in a batch. + Length groups are specified by boundaries. + Ex) boundaries = [b1, b2, b3] -> any batch is included either {x | b1 < length(x) <=b2} or {x | b2 < length(x) <= b3}. + + It removes samples which are not included in the boundaries. + Ex) boundaries = [b1, b2, b3] -> any x s.t. length(x) <= b1 or length(x) > b3 are discarded. + """ + + def __init__(self, dataset, batch_size, boundaries, num_replicas=None, rank=None, shuffle=True): + super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle) + self.lengths = dataset.lengths + self.batch_size = batch_size + self.boundaries = boundaries + + self.buckets, self.num_samples_per_bucket = self._create_buckets() + self.total_size = sum(self.num_samples_per_bucket) + self.num_samples = self.total_size // self.num_replicas + + def _create_buckets(self): + buckets = [[] for _ in range(len(self.boundaries) - 1)] + for i in range(len(self.lengths)): + length = self.lengths[i] + idx_bucket = self._bisect(length) + if idx_bucket != -1: + buckets[idx_bucket].append(i) + + for i in range(len(buckets) - 1, -1, -1): # second term debugged from 0 to -1 + if len(buckets[i]) == 0: + buckets.pop(i) + self.boundaries.pop(i + 1) + + num_samples_per_bucket = [] + for i in range(len(buckets)): + len_bucket = len(buckets[i]) + total_batch_size = self.num_replicas * self.batch_size + rem = (total_batch_size - (len_bucket % total_batch_size)) % total_batch_size + num_samples_per_bucket.append(len_bucket + rem) + return buckets, num_samples_per_bucket + + def __iter__(self): + # deterministically shuffle based on epoch + g = torch.Generator() + g.manual_seed(self.epoch) + + indices = [] + if self.shuffle: + for bucket in self.buckets: + indices.append(torch.randperm(len(bucket), generator=g).tolist()) + else: + for bucket in self.buckets: + indices.append(list(range(len(bucket)))) + + batches = [] + for i in range(len(self.buckets)): + bucket = self.buckets[i] + len_bucket = len(bucket) + ids_bucket = indices[i] + num_samples_bucket = self.num_samples_per_bucket[i] + + # add extra samples to make it evenly divisible + rem = num_samples_bucket - len_bucket + ids_bucket = ids_bucket + ids_bucket * (rem // len_bucket) + ids_bucket[:(rem % len_bucket)] + + # subsample + ids_bucket = ids_bucket[self.rank::self.num_replicas] + + # batching + for j in range(len(ids_bucket) // self.batch_size): + batch = [bucket[idx] for idx in ids_bucket[j * self.batch_size:(j + 1) * self.batch_size]] + batches.append(batch) + + if self.shuffle: + batch_ids = torch.randperm(len(batches), generator=g).tolist() + batches = [batches[i] for i in batch_ids] + self.batches = batches + + assert len(self.batches) * self.batch_size == self.num_samples + return iter(self.batches) + + def _bisect(self, x, lo=0, hi=None): + if hi is None: + hi = len(self.boundaries) - 1 + + if hi > lo: + mid = (hi + lo) // 2 + if self.boundaries[mid] < x and x <= self.boundaries[mid + 1]: + return mid + elif x <= self.boundaries[mid]: + return self._bisect(x, lo, mid) + else: + return self._bisect(x, mid + 1, hi) + else: + return -1 + + def __len__(self): + return self.num_samples // self.batch_size diff --git a/MMaDA/models/speech_tokenization/UVITS/mel_processing.py b/MMaDA/models/speech_tokenization/UVITS/mel_processing.py new file mode 100644 index 0000000000000000000000000000000000000000..572fb4e9e7306c0675f423fbcf6349780c402017 --- /dev/null +++ b/MMaDA/models/speech_tokenization/UVITS/mel_processing.py @@ -0,0 +1,114 @@ +import math +import os +import random +import torch +from torch import nn +import torch.nn.functional as F +import torch.utils.data +import numpy as np +import librosa +import librosa.util as librosa_util +from librosa.util import normalize, pad_center, tiny +from scipy.signal import get_window +from scipy.io.wavfile import read +from librosa.filters import mel as librosa_mel_fn + +MAX_WAV_VALUE = 32768.0 + + +def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): + """ + PARAMS + ------ + C: compression factor + """ + return torch.log(torch.clamp(x, min=clip_val) * C) + + +def dynamic_range_decompression_torch(x, C=1): + """ + PARAMS + ------ + C: compression factor used to compress + """ + return torch.exp(x) / C + + +def spectral_normalize_torch(magnitudes): + output = dynamic_range_compression_torch(magnitudes) + return output + + +def spectral_de_normalize_torch(magnitudes): + output = dynamic_range_decompression_torch(magnitudes) + return output + + +mel_basis = {} +hann_window = {} + + +def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False): + if torch.min(y) < -1.: + print('min value is ', torch.min(y)) + if torch.max(y) > 1.: + print('max value is ', torch.max(y)) + + global hann_window + dtype_device = str(y.dtype) + '_' + str(y.device) + wnsize_dtype_device = str(win_size) + '_' + dtype_device + if wnsize_dtype_device not in hann_window: + hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device) + + y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), + mode='reflect') + y = y.squeeze(1) + + spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device], + center=center, pad_mode='reflect', normalized=False, onesided=True) + + spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) + return spec + + +def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax): + global mel_basis + dtype_device = str(spec.dtype) + '_' + str(spec.device) + fmax_dtype_device = str(fmax) + '_' + dtype_device + if fmax_dtype_device not in mel_basis: + mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax) + mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device) + spec = torch.matmul(mel_basis[fmax_dtype_device], spec) + spec = spectral_normalize_torch(spec) + return spec + + +def mel_spectrogram_torch(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False): + if torch.min(y) < -1.: + print('min value is ', torch.min(y)) + if torch.max(y) > 1.: + print('max value is ', torch.max(y)) + + global mel_basis, hann_window + dtype_device = str(y.dtype) + '_' + str(y.device) + fmax_dtype_device = str(fmax) + '_' + dtype_device + wnsize_dtype_device = str(win_size) + '_' + dtype_device + if fmax_dtype_device not in mel_basis: + mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax) + mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=y.dtype, device=y.device) + if wnsize_dtype_device not in hann_window: + hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device) + + y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), + mode='reflect') + y = y.squeeze(1) + + spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device], + center=center, pad_mode='reflect', normalized=False, onesided=True) + + spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) + + spec = torch.matmul(mel_basis[fmax_dtype_device], spec) + spec = spectral_normalize_torch(spec) + + return spec diff --git a/MMaDA/models/speech_tokenization/UVITS/models.py b/MMaDA/models/speech_tokenization/UVITS/models.py new file mode 100644 index 0000000000000000000000000000000000000000..bd7db0b97805a2a3ecfbdc14cd818669248ef67f --- /dev/null +++ b/MMaDA/models/speech_tokenization/UVITS/models.py @@ -0,0 +1,674 @@ +import copy +import math +import torch +from torch import nn +from torch.nn import functional as F +from torch.autograd import Variable + +import commons +import modules +import attentions +import monotonic_align + +from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d +from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm +from commons import init_weights, get_padding +from script import vae_utils + + +class StochasticDurationPredictor(nn.Module): + def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, n_flows=4, gin_channels=0): + super().__init__() + filter_channels = in_channels # it needs to be removed from future version. + self.in_channels = in_channels + self.filter_channels = filter_channels + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.n_flows = n_flows + self.gin_channels = gin_channels + + self.log_flow = modules.Log() + self.flows = nn.ModuleList() + self.flows.append(modules.ElementwiseAffine(2)) + for i in range(n_flows): + self.flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)) + self.flows.append(modules.Flip()) + + self.post_pre = nn.Conv1d(1, filter_channels, 1) + self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1) + self.post_convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout) + self.post_flows = nn.ModuleList() + self.post_flows.append(modules.ElementwiseAffine(2)) + for i in range(4): + self.post_flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)) + self.post_flows.append(modules.Flip()) + + self.pre = nn.Conv1d(in_channels, filter_channels, 1) + self.proj = nn.Conv1d(filter_channels, filter_channels, 1) + self.convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout) + if gin_channels != 0: + self.cond = nn.Conv1d(gin_channels, filter_channels, 1) + + def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0): + x = torch.detach(x) + x = self.pre(x) + if g is not None: + g = torch.detach(g) + x = x + self.cond(g) + x = self.convs(x, x_mask) + x = self.proj(x) * x_mask + + if not reverse: + flows = self.flows + assert w is not None + + logdet_tot_q = 0 + h_w = self.post_pre(w) + h_w = self.post_convs(h_w, x_mask) + h_w = self.post_proj(h_w) * x_mask + e_q = torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype) * x_mask + z_q = e_q + for flow in self.post_flows: + z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w)) + logdet_tot_q += logdet_q + z_u, z1 = torch.split(z_q, [1, 1], 1) + u = torch.sigmoid(z_u) * x_mask + z0 = (w - u) * x_mask + logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2]) + logq = torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q ** 2)) * x_mask, [1, 2]) - logdet_tot_q + + logdet_tot = 0 + z0, logdet = self.log_flow(z0, x_mask) + logdet_tot += logdet + z = torch.cat([z0, z1], 1) + for flow in flows: + z, logdet = flow(z, x_mask, g=x, reverse=reverse) + logdet_tot = logdet_tot + logdet + nll = torch.sum(0.5 * (math.log(2 * math.pi) + (z ** 2)) * x_mask, [1, 2]) - logdet_tot + return nll + logq # [b] + else: + flows = list(reversed(self.flows)) + flows = flows[:-2] + [flows[-1]] # remove a useless vflow + z = torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale + for flow in flows: + z = flow(z, x_mask, g=x, reverse=reverse) + z0, z1 = torch.split(z, [1, 1], 1) + logw = z0 + return logw + + +class DurationPredictor(nn.Module): + def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0): + super().__init__() + + self.in_channels = in_channels + self.filter_channels = filter_channels + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.gin_channels = gin_channels + + self.drop = nn.Dropout(p_dropout) + self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2) + self.norm_1 = modules.LayerNorm(filter_channels) + self.conv_2 = nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2) + self.norm_2 = modules.LayerNorm(filter_channels) + self.proj = nn.Conv1d(filter_channels, 1, 1) + + if gin_channels != 0: + self.cond = nn.Conv1d(gin_channels, in_channels, 1) + + def forward(self, x, x_mask, g=None): + x = torch.detach(x) + if g is not None: + g = torch.detach(g) + x = x + self.cond(g) + x = self.conv_1(x * x_mask) + x = torch.relu(x) + x = self.norm_1(x) + x = self.drop(x) + x = self.conv_2(x * x_mask) + x = torch.relu(x) + x = self.norm_2(x) + x = self.drop(x) + x = self.proj(x * x_mask) + return x * x_mask + + +class TextEncoder(nn.Module): + def __init__(self, + n_vocab, + out_channels, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout): + super().__init__() + self.n_vocab = n_vocab + self.out_channels = out_channels + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + + self.emb = nn.Embedding(n_vocab, hidden_channels) + nn.init.normal_(self.emb.weight, 0.0, hidden_channels ** -0.5) + + self.encoder = attentions.Encoder( + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout) + self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) + + def forward(self, x, x_lengths): + x = self.emb(x) * math.sqrt(self.hidden_channels) # [b, t, h] + x = torch.transpose(x, 1, -1) # [b, h, t] + x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) + + x = self.encoder(x * x_mask, x_mask) + stats = self.proj(x) * x_mask + + m, logs = torch.split(stats, self.out_channels, dim=1) + return x, m, logs, x_mask + + +class ResidualCouplingBlock(nn.Module): + def __init__(self, + channels, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + n_flows=4, + gin_channels=0): + super().__init__() + self.channels = channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.n_flows = n_flows + self.gin_channels = gin_channels + + self.flows = nn.ModuleList() + for i in range(n_flows): + self.flows.append( + modules.ResidualCouplingLayer(channels, hidden_channels, kernel_size, dilation_rate, n_layers, + gin_channels=gin_channels, mean_only=True)) + self.flows.append(modules.Flip()) + + def forward(self, x, x_mask, g=None, reverse=False): + if not reverse: + for flow in self.flows: + x, _ = flow(x, x_mask, g=g, reverse=reverse) + else: + for flow in reversed(self.flows): + x = flow(x, x_mask, g=g, reverse=reverse) + return x + + +class PosteriorEncoder(nn.Module): + def __init__(self, + in_channels, + out_channels, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + gin_channels=0): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.gin_channels = gin_channels + + self.pre = nn.Conv1d(in_channels, hidden_channels, 1) + self.enc = modules.WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels) + self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) + + def forward(self, x, x_lengths, g=None): + x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) + x = self.pre(x) * x_mask + x = self.enc(x, x_mask, g=g) + stats = self.proj(x) * x_mask + m, logs = torch.split(stats, self.out_channels, dim=1) + z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask + return z, m, logs, x_mask + + +class Generator(torch.nn.Module): + def __init__(self, initial_channel, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, + upsample_initial_channel, upsample_kernel_sizes, gin_channels=0): + super(Generator, self).__init__() + self.num_kernels = len(resblock_kernel_sizes) + self.num_upsamples = len(upsample_rates) + self.conv_pre = Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3) + resblock = modules.ResBlock1 if resblock == '1' else modules.ResBlock2 + + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): + self.ups.append(weight_norm( + ConvTranspose1d(upsample_initial_channel // (2 ** i), upsample_initial_channel // (2 ** (i + 1)), + k, u, padding=(k - u) // 2))) + + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = upsample_initial_channel // (2 ** (i + 1)) + for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)): + self.resblocks.append(resblock(ch, k, d)) + + self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False) + self.ups.apply(init_weights) + + if gin_channels != 0: + self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1) + + def forward(self, x, g=None): + x = self.conv_pre(x) + if g is not None: + x = x + self.cond(g) + + for i in range(self.num_upsamples): + x = F.leaky_relu(x, modules.LRELU_SLOPE) + x = self.ups[i](x) + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x) + else: + xs += self.resblocks[i * self.num_kernels + j](x) + x = xs / self.num_kernels + x = F.leaky_relu(x) + x = self.conv_post(x) + x = torch.tanh(x) + + return x + + def remove_weight_norm(self): + print('Removing weight norm...') + for l in self.ups: + remove_weight_norm(l) + for l in self.resblocks: + l.remove_weight_norm() + + +class DiscriminatorP(torch.nn.Module): + def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): + super(DiscriminatorP, self).__init__() + self.period = period + self.use_spectral_norm = use_spectral_norm + norm_f = weight_norm if use_spectral_norm == False else spectral_norm + self.convs = nn.ModuleList([ + norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), + norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), + norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), + norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), + norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(get_padding(kernel_size, 1), 0))), + ]) + self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) + + def forward(self, x): + fmap = [] + + # 1d to 2d + b, c, t = x.shape + if t % self.period != 0: # pad first + n_pad = self.period - (t % self.period) + x = F.pad(x, (0, n_pad), "reflect") + t = t + n_pad + x = x.view(b, c, t // self.period, self.period) + + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, modules.LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class DiscriminatorS(torch.nn.Module): + def __init__(self, use_spectral_norm=False): + super(DiscriminatorS, self).__init__() + norm_f = weight_norm if use_spectral_norm == False else spectral_norm + self.convs = nn.ModuleList([ + norm_f(Conv1d(1, 16, 15, 1, padding=7)), + norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)), + norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)), + norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)), + norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)), + norm_f(Conv1d(1024, 1024, 5, 1, padding=2)), + ]) + self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1)) + + def forward(self, x): + fmap = [] + + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, modules.LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class MultiPeriodDiscriminator(torch.nn.Module): + def __init__(self, use_spectral_norm=False): + super(MultiPeriodDiscriminator, self).__init__() + periods = [2, 3, 5, 7, 11] + + discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)] + discs = discs + [DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods] + self.discriminators = nn.ModuleList(discs) + + def forward(self, y, y_hat): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for i, d in enumerate(self.discriminators): + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + y_d_rs.append(y_d_r) + y_d_gs.append(y_d_g) + fmap_rs.append(fmap_r) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +class MelStyleEncoder(nn.Module): + ''' MelStyleEncoder ''' + + def __init__(self, in_dim, style_hidden, style_vector_dim, style_kernel_size, style_head, dropout, + mode='determinant'): + super(MelStyleEncoder, self).__init__() + self.in_dim = in_dim + self.hidden_dim = style_hidden + self.out_dim = style_vector_dim + self.kernel_size = style_kernel_size + self.n_head = style_head + self.dropout = dropout + self.mode = mode + + self.spectral = nn.Sequential( + modules.LinearNorm(self.in_dim, self.hidden_dim), + modules.Mish(), + nn.Dropout(self.dropout), + modules.LinearNorm(self.hidden_dim, self.hidden_dim), + modules.Mish(), + nn.Dropout(self.dropout) + ) + + self.temporal = nn.Sequential( + modules.Conv1dGLU(self.hidden_dim, self.hidden_dim, self.kernel_size, self.dropout), + modules.Conv1dGLU(self.hidden_dim, self.hidden_dim, self.kernel_size, self.dropout), + ) + + self.slf_attn = modules.MultiHeadAttention(self.n_head, self.hidden_dim, + self.hidden_dim // self.n_head, self.hidden_dim // self.n_head, + self.dropout) + self.fc = modules.LinearNorm(self.hidden_dim, self.out_dim) + + def temporal_avg_pool(self, x, mask=None): + if mask is None: + out = torch.mean(x, dim=1) + else: + len_ = (~mask).sum(dim=1).unsqueeze(1) + x = x.masked_fill(mask.unsqueeze(-1), 0) + x = x.sum(dim=1) + out = torch.div(x, len_) + return out + + def forward(self, x, mask=None): + + max_len = x.shape[1] + if mask is not None: + mask = (mask.int() == 0).squeeze(1) + slf_attn_mask = mask.unsqueeze(1).expand(-1, max_len, -1) + else: + slf_attn_mask = None + # spectral + x = self.spectral(x) + # temporal + x = x.transpose(1, 2) + x = self.temporal(x) + x = x.transpose(1, 2) + # self-attention + if mask is not None: + x = x.masked_fill(mask.unsqueeze(-1), 0) + x, _ = self.slf_attn(x, mask=slf_attn_mask) + # fc + x = self.fc(x) + # temoral average pooling + w = self.temporal_avg_pool(x, mask=mask) + w = F.normalize(w, dim=1) + return w + + +class SynthesizerTrn(nn.Module): + """ + Synthesizer for Training + """ + + def __init__(self, + n_vocab, + spec_channels, + segment_size, + inter_channels, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout, + resblock, + resblock_kernel_sizes, + resblock_dilation_sizes, + upsample_rates, + upsample_initial_channel, + upsample_kernel_sizes, + n_speakers=0, + gin_channels=0, + use_sdp=True, + use_ref_enc=False, + ref_mode='determinant', + **kwargs): + + super().__init__() + self.n_vocab = n_vocab + self.spec_channels = spec_channels + self.inter_channels = inter_channels + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.resblock = resblock + self.resblock_kernel_sizes = resblock_kernel_sizes + self.resblock_dilation_sizes = resblock_dilation_sizes + self.upsample_rates = upsample_rates + self.upsample_initial_channel = upsample_initial_channel + self.upsample_kernel_sizes = upsample_kernel_sizes + self.segment_size = segment_size + self.n_speakers = n_speakers + self.gin_channels = gin_channels + + self.use_sdp = use_sdp + self.use_ref_enc = use_ref_enc + + self.enc_p = TextEncoder(n_vocab, + inter_channels, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout) + self.dec = Generator(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, + upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels) + self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, + gin_channels=gin_channels) + self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels) + + if use_sdp: + self.dp = StochasticDurationPredictor(hidden_channels, 192, 3, 0.5, 4, gin_channels=gin_channels) + else: + self.dp = DurationPredictor(hidden_channels, 256, 3, 0.5, gin_channels=gin_channels) + + self.ref_mode = ref_mode + if n_speakers > 1: + if ref_mode == 'entropy': + self.emb_g = nn.Parameter( + torch.randn(n_speakers, gin_channels * 2), + requires_grad=True) + else: + self.emb_g = nn.Embedding(n_speakers, gin_channels) + + if use_ref_enc: + self.enc_r = MelStyleEncoder(spec_channels, hidden_channels, gin_channels, 5, 2, p_dropout, mode=ref_mode) + + def forward(self, x, x_lengths, y, y_lengths, sid=None): + + x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths) + g, spk_centroid = None, None + y_mask = None + aux = {} + if self.n_speakers > 0: + if self.use_ref_enc: + y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(y.dtype) # [b, 1, t] + aux['y_mask'] = y_mask + + if self.ref_mode == 'determinant': + g = self.enc_r(y.transpose(1, 2), y_mask).unsqueeze(-1) # [b, h, 1] + spk_centroid = self.emb_g(sid).unsqueeze(-1) # [b, h, 1] + aux['spk_centroid'] = spk_centroid + + elif self.ref_mode == 'entropy': + spk_mean, spk_var = vae_utils.gaussian_parameters(self.emb_g) # p^2 + mean, var = spk_mean[sid], spk_var[sid] + + g = self.enc_r(y.transpose(1, 2), y_mask) + g = g.unsqueeze(-1) + + aux['spk_mean'] = mean + aux['spk_var'] = var + else: + raise NotImplementedError + else: + g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1] + + z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g) + z_p = self.flow(z, y_mask, g=g) + + with torch.no_grad(): + # negative cross-entropy + s_p_sq_r = torch.exp(-2 * logs_p) # [b, d, t] + neg_cent1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1], keepdim=True) # [b, 1, t_s] + neg_cent2 = torch.matmul(-0.5 * (z_p ** 2).transpose(1, 2), + s_p_sq_r) # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s] + neg_cent3 = torch.matmul(z_p.transpose(1, 2), (m_p * s_p_sq_r)) # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s] + neg_cent4 = torch.sum(-0.5 * (m_p ** 2) * s_p_sq_r, [1], keepdim=True) # [b, 1, t_s] + neg_cent = neg_cent1 + neg_cent2 + neg_cent3 + neg_cent4 + + attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1) + attn = monotonic_align.maximum_path(neg_cent, attn_mask.squeeze(1)).unsqueeze(1).detach() + + w = attn.sum(2) + if self.use_sdp: + l_length = self.dp(x, x_mask, w, g=g) + l_length = l_length / torch.sum(x_mask) + else: + logw_ = torch.log(w + 1e-6) * x_mask + logw = self.dp(x, x_mask, g=g) + l_length = torch.sum((logw - logw_) ** 2, [1, 2]) / torch.sum(x_mask) # for averaging + + # expand prior + m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2) + logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2) + + z_slice, ids_slice = commons.rand_slice_segments(z, y_lengths, self.segment_size) + o = self.dec(z_slice, g=g) + if self.n_speakers > 0: + return o, l_length, attn, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q), g, aux + else: + return o, l_length, attn, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q) + + def infer(self, x, x_lengths, sid=None, noise_scale=1, length_scale=1, noise_scale_w=1., max_len=None, spec=None, + spec_lengths=None): + x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths) + g = None + if sid is not None: + if self.n_speakers > 0: + if self.use_ref_enc: + assert spec is not None + spec_mask = torch.unsqueeze(commons.sequence_mask(spec_lengths, spec.size(2)), 1).to(spec.dtype) + g = self.enc_r(spec.transpose(1, 2), spec_mask).unsqueeze(-1) + else: + g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1] + + if self.use_sdp: + logw = self.dp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w) + else: + logw = self.dp(x, x_mask, g=g) + w = torch.exp(logw) * x_mask * length_scale + w_ceil = torch.ceil(w) + y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long() + y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, None), 1).to(x_mask.dtype) + attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1) + attn = commons.generate_path(w_ceil, attn_mask) + + m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t'] + logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, + 2) # [b, t', t], [b, t, d] -> [b, d, t'] + + z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale + z = self.flow(z_p, y_mask, g=g, reverse=True) + o = self.dec((z * y_mask)[:, :, :max_len], g=g) + return o, attn, y_mask, (z, z_p, m_p, logs_p) + + def get_speaker_style_embedding(self, spec, spec_lengths): + assert self.use_ref_enc + spec_mask = torch.unsqueeze(commons.sequence_mask(spec_lengths, spec.size(2)), 1).to(spec.dtype) + g = self.enc_r(spec.transpose(1, 2), spec_mask) + return g + + def synthesis_from_content_unit_style_embedding(self, x, x_lengths,style_embedding, noise_scale=1, length_scale=1, noise_scale_w=1., max_len=None, ): + x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths) + + g=style_embedding # should of shape [b, h, 1] + + if self.use_sdp: + logw = self.dp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w) + else: + logw = self.dp(x, x_mask, g=g) + w = torch.exp(logw) * x_mask * length_scale + w_ceil = torch.ceil(w) + y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long() + y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, None), 1).to(x_mask.dtype) + attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1) + attn = commons.generate_path(w_ceil, attn_mask) + + m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t'] + logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, + 2) # [b, t', t], [b, t, d] -> [b, d, t'] + + z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale + z = self.flow(z_p, y_mask, g=g, reverse=True) + o = self.dec((z * y_mask)[:, :, :max_len], g=g) + return o, attn, y_mask, (z, z_p, m_p, logs_p) diff --git a/MMaDA/models/speech_tokenization/UVITS/modules.py b/MMaDA/models/speech_tokenization/UVITS/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..b4dd724eea79b871136732f5957ae5f065cf7c8a --- /dev/null +++ b/MMaDA/models/speech_tokenization/UVITS/modules.py @@ -0,0 +1,569 @@ +import copy +import math +import numpy as np +import scipy +import torch +from torch import nn +from torch.nn import functional as F + +from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d +from torch.nn.utils import weight_norm, remove_weight_norm + +import commons +from commons import init_weights, get_padding +from transforms import piecewise_rational_quadratic_transform + +LRELU_SLOPE = 0.1 + + +class LayerNorm(nn.Module): + def __init__(self, channels, eps=1e-5): + super().__init__() + self.channels = channels + self.eps = eps + + self.GAMMA = nn.Parameter(torch.ones(channels)) + self.BETA = nn.Parameter(torch.zeros(channels)) + + def forward(self, x): + x = x.transpose(1, -1) + x = F.layer_norm(x, (self.channels,), self.GAMMA, self.BETA, self.eps) + return x.transpose(1, -1) + + +class ConvReluNorm(nn.Module): + def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout): + super().__init__() + self.in_channels = in_channels + self.hidden_channels = hidden_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.n_layers = n_layers + self.p_dropout = p_dropout + assert n_layers > 1, "Number of layers should be larger than 0." + + self.conv_layers = nn.ModuleList() + self.norm_layers = nn.ModuleList() + self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2)) + self.norm_layers.append(LayerNorm(hidden_channels)) + self.relu_drop = nn.Sequential( + nn.ReLU(), + nn.Dropout(p_dropout)) + for _ in range(n_layers - 1): + self.conv_layers.append(nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2)) + self.norm_layers.append(LayerNorm(hidden_channels)) + self.proj = nn.Conv1d(hidden_channels, out_channels, 1) + self.proj.weight.data.zero_() + self.proj.bias.data.zero_() + + def forward(self, x, x_mask): + x_org = x + for i in range(self.n_layers): + x = self.conv_layers[i](x * x_mask) + x = self.norm_layers[i](x) + x = self.relu_drop(x) + x = x_org + self.proj(x) + return x * x_mask + + +class DDSConv(nn.Module): + """ + Dialted and Depth-Separable Convolution + """ + + def __init__(self, channels, kernel_size, n_layers, p_dropout=0.): + super().__init__() + self.channels = channels + self.kernel_size = kernel_size + self.n_layers = n_layers + self.p_dropout = p_dropout + + self.drop = nn.Dropout(p_dropout) + self.convs_sep = nn.ModuleList() + self.convs_1x1 = nn.ModuleList() + self.norms_1 = nn.ModuleList() + self.norms_2 = nn.ModuleList() + for i in range(n_layers): + dilation = kernel_size ** i + padding = (kernel_size * dilation - dilation) // 2 + self.convs_sep.append(nn.Conv1d(channels, channels, kernel_size, + groups=channels, dilation=dilation, padding=padding + )) + self.convs_1x1.append(nn.Conv1d(channels, channels, 1)) + self.norms_1.append(LayerNorm(channels)) + self.norms_2.append(LayerNorm(channels)) + + def forward(self, x, x_mask, g=None): + if g is not None: + x = x + g + for i in range(self.n_layers): + y = self.convs_sep[i](x * x_mask) + y = self.norms_1[i](y) + y = F.gelu(y) + y = self.convs_1x1[i](y) + y = self.norms_2[i](y) + y = F.gelu(y) + y = self.drop(y) + x = x + y + return x * x_mask + + +class WN(torch.nn.Module): + def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0, p_dropout=0): + super(WN, self).__init__() + assert (kernel_size % 2 == 1) + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size, + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.gin_channels = gin_channels + self.p_dropout = p_dropout + + self.in_layers = torch.nn.ModuleList() + self.res_skip_layers = torch.nn.ModuleList() + self.drop = nn.Dropout(p_dropout) + + if gin_channels != 0: + cond_layer = torch.nn.Conv1d(gin_channels, 2 * hidden_channels * n_layers, 1) + self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight') + + for i in range(n_layers): + dilation = dilation_rate ** i + padding = int((kernel_size * dilation - dilation) / 2) + in_layer = torch.nn.Conv1d(hidden_channels, 2 * hidden_channels, kernel_size, + dilation=dilation, padding=padding) + in_layer = torch.nn.utils.weight_norm(in_layer, name='weight') + self.in_layers.append(in_layer) + + # last one is not necessary + if i < n_layers - 1: + res_skip_channels = 2 * hidden_channels + else: + res_skip_channels = hidden_channels + + res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1) + res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight') + self.res_skip_layers.append(res_skip_layer) + + def forward(self, x, x_mask, g=None, **kwargs): + output = torch.zeros_like(x) + n_channels_tensor = torch.IntTensor([self.hidden_channels]) + + if g is not None: + g = self.cond_layer(g) + + for i in range(self.n_layers): + x_in = self.in_layers[i](x) + if g is not None: + cond_offset = i * 2 * self.hidden_channels + g_l = g[:, cond_offset:cond_offset + 2 * self.hidden_channels, :] + else: + g_l = torch.zeros_like(x_in) + + acts = commons.fused_add_tanh_sigmoid_multiply( + x_in, + g_l, + n_channels_tensor) + acts = self.drop(acts) + + res_skip_acts = self.res_skip_layers[i](acts) + if i < self.n_layers - 1: + res_acts = res_skip_acts[:, :self.hidden_channels, :] + x = (x + res_acts) * x_mask + output = output + res_skip_acts[:, self.hidden_channels:, :] + else: + output = output + res_skip_acts + return output * x_mask + + def remove_weight_norm(self): + if self.gin_channels != 0: + torch.nn.utils.remove_weight_norm(self.cond_layer) + for l in self.in_layers: + torch.nn.utils.remove_weight_norm(l) + for l in self.res_skip_layers: + torch.nn.utils.remove_weight_norm(l) + + +class ResBlock1(torch.nn.Module): + def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): + super(ResBlock1, self).__init__() + self.convs1 = nn.ModuleList([ + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]))) + ]) + self.convs1.apply(init_weights) + + self.convs2 = nn.ModuleList([ + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, + padding=get_padding(kernel_size, 1))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, + padding=get_padding(kernel_size, 1))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, + padding=get_padding(kernel_size, 1))) + ]) + self.convs2.apply(init_weights) + + def forward(self, x, x_mask=None): + for c1, c2 in zip(self.convs1, self.convs2): + xt = F.leaky_relu(x, LRELU_SLOPE) + if x_mask is not None: + xt = xt * x_mask + xt = c1(xt) + xt = F.leaky_relu(xt, LRELU_SLOPE) + if x_mask is not None: + xt = xt * x_mask + xt = c2(xt) + x = xt + x + if x_mask is not None: + x = x * x_mask + return x + + def remove_weight_norm(self): + for l in self.convs1: + remove_weight_norm(l) + for l in self.convs2: + remove_weight_norm(l) + + +class ResBlock2(torch.nn.Module): + def __init__(self, channels, kernel_size=3, dilation=(1, 3)): + super(ResBlock2, self).__init__() + self.convs = nn.ModuleList([ + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]))) + ]) + self.convs.apply(init_weights) + + def forward(self, x, x_mask=None): + for c in self.convs: + xt = F.leaky_relu(x, LRELU_SLOPE) + if x_mask is not None: + xt = xt * x_mask + xt = c(xt) + x = xt + x + if x_mask is not None: + x = x * x_mask + return x + + def remove_weight_norm(self): + for l in self.convs: + remove_weight_norm(l) + + +class Log(nn.Module): + def forward(self, x, x_mask, reverse=False, **kwargs): + if not reverse: + y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask + logdet = torch.sum(-y, [1, 2]) + return y, logdet + else: + x = torch.exp(x) * x_mask + return x + + +class Flip(nn.Module): + def forward(self, x, *args, reverse=False, **kwargs): + x = torch.flip(x, [1]) + if not reverse: + logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device) + return x, logdet + else: + return x + + +class ElementwiseAffine(nn.Module): + def __init__(self, channels): + super().__init__() + self.channels = channels + self.m = nn.Parameter(torch.zeros(channels, 1)) + self.logs = nn.Parameter(torch.zeros(channels, 1)) + + def forward(self, x, x_mask, reverse=False, **kwargs): + if not reverse: + y = self.m + torch.exp(self.logs) * x + y = y * x_mask + logdet = torch.sum(self.logs * x_mask, [1, 2]) + return y, logdet + else: + x = (x - self.m) * torch.exp(-self.logs) * x_mask + return x + + +class ResidualCouplingLayer(nn.Module): + def __init__(self, + channels, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + p_dropout=0, + gin_channels=0, + mean_only=False): + assert channels % 2 == 0, "channels should be divisible by 2" + super().__init__() + self.channels = channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.half_channels = channels // 2 + self.mean_only = mean_only + + self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1) + self.enc = WN(hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=p_dropout, + gin_channels=gin_channels) + self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1) + self.post.weight.data.zero_() + self.post.bias.data.zero_() + + def forward(self, x, x_mask, g=None, reverse=False): + x0, x1 = torch.split(x, [self.half_channels] * 2, 1) + h = self.pre(x0) * x_mask + h = self.enc(h, x_mask, g=g) + stats = self.post(h) * x_mask + if not self.mean_only: + m, logs = torch.split(stats, [self.half_channels] * 2, 1) + else: + m = stats + logs = torch.zeros_like(m) + + if not reverse: + x1 = m + x1 * torch.exp(logs) * x_mask + x = torch.cat([x0, x1], 1) + logdet = torch.sum(logs, [1, 2]) + return x, logdet + else: + x1 = (x1 - m) * torch.exp(-logs) * x_mask + x = torch.cat([x0, x1], 1) + return x + + +class ConvFlow(nn.Module): + def __init__(self, in_channels, filter_channels, kernel_size, n_layers, num_bins=10, tail_bound=5.0): + super().__init__() + self.in_channels = in_channels + self.filter_channels = filter_channels + self.kernel_size = kernel_size + self.n_layers = n_layers + self.num_bins = num_bins + self.tail_bound = tail_bound + self.half_channels = in_channels // 2 + + self.pre = nn.Conv1d(self.half_channels, filter_channels, 1) + self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.) + self.proj = nn.Conv1d(filter_channels, self.half_channels * (num_bins * 3 - 1), 1) + self.proj.weight.data.zero_() + self.proj.bias.data.zero_() + + def forward(self, x, x_mask, g=None, reverse=False): + x0, x1 = torch.split(x, [self.half_channels] * 2, 1) + h = self.pre(x0) + h = self.convs(h, x_mask, g=g) + h = self.proj(h) * x_mask + + b, c, t = x0.shape + h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) # [b, cx?, t] -> [b, c, t, ?] + + unnormalized_widths = h[..., :self.num_bins] / math.sqrt(self.filter_channels) + unnormalized_heights = h[..., self.num_bins:2 * self.num_bins] / math.sqrt(self.filter_channels) + unnormalized_derivatives = h[..., 2 * self.num_bins:] + + x1, logabsdet = piecewise_rational_quadratic_transform(x1, + unnormalized_widths, + unnormalized_heights, + unnormalized_derivatives, + inverse=reverse, + tails='linear', + tail_bound=self.tail_bound + ) + + x = torch.cat([x0, x1], 1) * x_mask + logdet = torch.sum(logabsdet * x_mask, [1, 2]) + if not reverse: + return x, logdet + else: + return x + + +class LinearNorm(nn.Module): + def __init__(self, + in_channels, + out_channels, + bias=True, + spectral_norm=False, + ): + super(LinearNorm, self).__init__() + self.fc = nn.Linear(in_channels, out_channels, bias) + + if spectral_norm: + self.fc = nn.utils.spectral_norm(self.fc) + + def forward(self, input): + out = self.fc(input) + return out + + +class Mish(nn.Module): + def __init__(self): + super(Mish, self).__init__() + + def forward(self, x): + return x * torch.tanh(F.softplus(x)) + + +class LinearNorm(nn.Module): + def __init__(self, + in_channels, + out_channels, + bias=True, + spectral_norm=False, + ): + super(LinearNorm, self).__init__() + self.fc = nn.Linear(in_channels, out_channels, bias) + + if spectral_norm: + self.fc = nn.utils.spectral_norm(self.fc) + + def forward(self, input): + out = self.fc(input) + return out + + +class ConvNorm(nn.Module): + def __init__(self, + in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=None, + dilation=1, + bias=True, + spectral_norm=False, + ): + super(ConvNorm, self).__init__() + + if padding is None: + assert (kernel_size % 2 == 1) + padding = int(dilation * (kernel_size - 1) / 2) + + self.conv = torch.nn.Conv1d(in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + bias=bias) + + if spectral_norm: + self.conv = nn.utils.spectral_norm(self.conv) + + def forward(self, input): + out = self.conv(input) + return out + + +class MultiHeadAttention(nn.Module): + ''' Multi-Head Attention module ''' + + def __init__(self, n_head, d_model, d_k, d_v, dropout=0., spectral_norm=False): + super().__init__() + + self.n_head = n_head + self.d_k = d_k + self.d_v = d_v + + self.w_qs = nn.Linear(d_model, n_head * d_k) + self.w_ks = nn.Linear(d_model, n_head * d_k) + self.w_vs = nn.Linear(d_model, n_head * d_v) + + self.attention = ScaledDotProductAttention(temperature=np.power(d_model, 0.5), dropout=dropout) + + self.fc = nn.Linear(n_head * d_v, d_model) + self.dropout = nn.Dropout(dropout) + + if spectral_norm: + self.w_qs = nn.utils.spectral_norm(self.w_qs) + self.w_ks = nn.utils.spectral_norm(self.w_ks) + self.w_vs = nn.utils.spectral_norm(self.w_vs) + self.fc = nn.utils.spectral_norm(self.fc) + + def forward(self, x, mask=None): + d_k, d_v, n_head = self.d_k, self.d_v, self.n_head + sz_b, len_x, _ = x.size() + + residual = x + + q = self.w_qs(x).view(sz_b, len_x, n_head, d_k) + k = self.w_ks(x).view(sz_b, len_x, n_head, d_k) + v = self.w_vs(x).view(sz_b, len_x, n_head, d_v) + q = q.permute(2, 0, 1, 3).contiguous().view(-1, + len_x, d_k) # (n*b) x lq x dk + k = k.permute(2, 0, 1, 3).contiguous().view(-1, + len_x, d_k) # (n*b) x lk x dk + v = v.permute(2, 0, 1, 3).contiguous().view(-1, + len_x, d_v) # (n*b) x lv x dv + + if mask is not None: + slf_mask = mask.repeat(n_head, 1, 1) # (n*b) x .. x .. + else: + slf_mask = None + output, attn = self.attention(q, k, v, mask=slf_mask) + + output = output.view(n_head, sz_b, len_x, d_v) + output = output.permute(1, 2, 0, 3).contiguous().view( + sz_b, len_x, -1) # b x lq x (n*dv) + + output = self.fc(output) + + output = self.dropout(output) + residual + return output, attn + + +class ScaledDotProductAttention(nn.Module): + ''' Scaled Dot-Product Attention ''' + + def __init__(self, temperature, dropout): + super().__init__() + self.temperature = temperature + self.softmax = nn.Softmax(dim=2) + self.dropout = nn.Dropout(dropout) + + def forward(self, q, k, v, mask=None): + attn = torch.bmm(q, k.transpose(1, 2)) + attn = attn / self.temperature + + if mask is not None: + attn = attn.masked_fill(mask, -np.inf) + + attn = self.softmax(attn) + p_attn = self.dropout(attn) + + output = torch.bmm(p_attn, v) + return output, attn + + +class Conv1dGLU(nn.Module): + ''' + Conv1d + GLU(Gated Linear Unit) with residual connection. + For GLU refer to https://arxiv.org/abs/1612.08083 paper. + ''' + + def __init__(self, in_channels, out_channels, kernel_size, dropout): + super(Conv1dGLU, self).__init__() + self.out_channels = out_channels + self.conv1 = ConvNorm(in_channels, 2 * out_channels, kernel_size=kernel_size) + self.dropout = nn.Dropout(dropout) + + def forward(self, x): + residual = x + x = self.conv1(x) + x1, x2 = torch.split(x, split_size_or_sections=self.out_channels, dim=1) + x = x1 * torch.sigmoid(x2) + x = residual + self.dropout(x) + return x diff --git a/MMaDA/models/speech_tokenization/UVITS/my_UVITS_model/ROMA_1n8g_u2s_40ms_multilingual_8888_xujing_cosyvoice_EN_CH_female_male_FT_20240812/config.json b/MMaDA/models/speech_tokenization/UVITS/my_UVITS_model/ROMA_1n8g_u2s_40ms_multilingual_8888_xujing_cosyvoice_EN_CH_female_male_FT_20240812/config.json new file mode 100644 index 0000000000000000000000000000000000000000..5c935b729734c5b41edcd286f1669e0c68434151 --- /dev/null +++ b/MMaDA/models/speech_tokenization/UVITS/my_UVITS_model/ROMA_1n8g_u2s_40ms_multilingual_8888_xujing_cosyvoice_EN_CH_female_male_FT_20240812/config.json @@ -0,0 +1,91 @@ +{ + "train": { + "log_interval": 200, + "eval_interval": 1000, + "seed": 1234, + "epochs": 10000, + "learning_rate": 2e-4, + "betas": [ + 0.8, + 0.99 + ], + "eps": 1e-9, + "batch_size": 32, + "fp16_run": false, + "lr_decay": 0.999875, + "segment_size": 8192, + "init_lr_ratio": 1, + "warmup_epochs": 0, + "c_mel": 45, + "c_kl": 1.0, + "c_spk": 0.0 + }, + "data": { + "training_files": "/home/ma-user/work/daxintan/data/for_U2S_training/xujing_cosyvoice_40ms_multilingual_8888/EN_CH_female_male/xujing_cosyvoice_EN_CH_female_male_wav_sid_reduced_unit.txt", + "validation_files": "/home/ma-user/work/daxintan/data/for_U2S_training/LibriTTS_AISHELL1_40ms_multilingual_8888/LibriTTS_test-clean_AISHELL1_test_wav_sid_reduced_unit.txt", + "text_cleaners": [ + "text_split" + ], + "max_wav_value": 32768.0, + "sampling_rate": 22050, + "filter_length": 1024, + "hop_length": 256, + "win_length": 1024, + "n_mel_channels": 80, + "mel_fmin": 0.0, + "mel_fmax": null, + "add_blank": false, + "n_speakers": 1344, + "cleaned_text": false, + "max_text_len": 1000 + }, + "model": { + "inter_channels": 192, + "hidden_channels": 192, + "filter_channels": 768, + "n_heads": 2, + "n_layers": 6, + "kernel_size": 3, + "p_dropout": 0.1, + "resblock": "1", + "resblock_kernel_sizes": [ + 3, + 7, + 11 + ], + "resblock_dilation_sizes": [ + [ + 1, + 3, + 5 + ], + [ + 1, + 3, + 5 + ], + [ + 1, + 3, + 5 + ] + ], + "upsample_rates": [ + 8, + 8, + 2, + 2 + ], + "upsample_initial_channel": 512, + "upsample_kernel_sizes": [ + 16, + 16, + 4, + 4 + ], + "n_layers_q": 3, + "use_spectral_norm": false, + "gin_channels": 256, + "use_ref_enc": true + } +} diff --git a/MMaDA/models/speech_tokenization/UVITS/my_synthesis/my_synthesis_for_speech_unit_sequence_recombination.py b/MMaDA/models/speech_tokenization/UVITS/my_synthesis/my_synthesis_for_speech_unit_sequence_recombination.py new file mode 100644 index 0000000000000000000000000000000000000000..49b5c246de8a08223bfa2ca588bee5fa2871dbe3 --- /dev/null +++ b/MMaDA/models/speech_tokenization/UVITS/my_synthesis/my_synthesis_for_speech_unit_sequence_recombination.py @@ -0,0 +1,47 @@ +from pathlib import Path + +import torch +from scipy.io.wavfile import write +from tqdm import tqdm + +import utils +from models.speech_tokenization.UVITS.models import SynthesizerTrn +from text import text_to_sequence + +import random +import shutil +import librosa +import numpy as np +import json + +from data_utils import load_wav_to_torch, spectrogram_torch + + +def get_audio(filename, hps): + audio, sampling_rate = librosa.load(filename, sr=hps.sampling_rate) + audio_norm = torch.FloatTensor(audio.astype(np.float32)).unsqueeze(0) + spec = spectrogram_torch(audio_norm, hps.filter_length, + hps.sampling_rate, hps.hop_length, hps.win_length, + center=False) + spec = torch.squeeze(spec, 0) + return spec, audio_norm + + +def get_text_unit_list(TTS_input_output_file): + text_list, unit_seq_list = [], [] + with open(TTS_input_output_file, 'r') as f: + for line in f: + if line.startswith('input:'): + text_list.append(line.lstrip('input:').rstrip('\n').strip()) + elif line.startswith('output: '): + unit_seq_list.append(line.lstrip('output:').rstrip('\n').strip()) + return text_list, unit_seq_list + + +def get_U2S_config_checkpoint_file(unit_type, language='English'): + assert language in ['English', 'Chinese'] + assert unit_type =='40ms_multilingual_8888_xujing_cosyvoice_FT' + # English and Chinese using the same UVITS model and config!! + config_file = "./speech_tokenization/UVITS/my_UVITS_model/ROMA_1n8g_u2s_40ms_multilingual_8888_xujing_cosyvoice_EN_CH_female_male_FT_20240812/config.json" + checkpoint_file = "./speech_tokenization/UVITS/my_UVITS_model/ROMA_1n8g_u2s_40ms_multilingual_8888_xujing_cosyvoice_EN_CH_female_male_FT_20240812/saved_checkpoint/G_322000.pth" + return config_file, checkpoint_file \ No newline at end of file diff --git a/MMaDA/models/speech_tokenization/UVITS/script/grad_multiply.py b/MMaDA/models/speech_tokenization/UVITS/script/grad_multiply.py new file mode 100644 index 0000000000000000000000000000000000000000..dec0ccfe395abdcbc03d53cc3dce9873c44586e8 --- /dev/null +++ b/MMaDA/models/speech_tokenization/UVITS/script/grad_multiply.py @@ -0,0 +1,63 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch + + +class GradMultiply(torch.autograd.Function): + @staticmethod + def forward(ctx, x, scale): + ctx.scale = scale + res = x.new(x) + return res + + @staticmethod + def backward(ctx, grad): + return grad * ctx.scale, None + + +def grad_multiply_wrapper(module, rate): + class GMModel(torch.nn.Module): + def __init__(self): + super(GMModel, self).__init__() + self.module = module + self.rate = rate + + def forward(self, *args, **kwargs): + if self.rate > 0: + features = self.module(*args, **kwargs) + if self.rate != 1.0: + if isinstance(features, torch.Tensor): + features = GradMultiply.apply(features, self.rate) + elif isinstance(features, tuple): + features = (GradMultiply.apply(f, self.rate) for f in features) + elif isinstance(features, dict): + features = {k: GradMultiply.apply(f, self.rate) for k, f in features.items()} + else: + with torch.no_grad(): + features = module(*args, **kwargs) + return features + + return GMModel() + + +if __name__ == '__main__': + class M(torch.nn.Module): + def __init__(self): + super(M, self).__init__() + self.a = torch.nn.Parameter(torch.zeros(1)) + + def forward(self, x): + return self.a * x + + + m = M() + m = grad_multiply_wrapper(m, 0.3) + optim = torch.optim.SGD(m.parameters(), lr=0.1, momentum=0) + optim.zero_grad() + loss = m(torch.ones(1)) + loss.backward() + optim.step() + print(m.module.a) diff --git a/MMaDA/models/speech_tokenization/UVITS/script/html_generation.py b/MMaDA/models/speech_tokenization/UVITS/script/html_generation.py new file mode 100644 index 0000000000000000000000000000000000000000..5ebeb9dcd162fef5e998531d257a3f252a080885 --- /dev/null +++ b/MMaDA/models/speech_tokenization/UVITS/script/html_generation.py @@ -0,0 +1,232 @@ +""" +Author: zhengnianzu +Place: shenzhen +Time: 2020.8.25 +Update: 2022.10.17 +""" +from ast import arg +from distutils.command.config import config +from inspect import ArgSpec +import os +import sys +import base64 +import argparse +import shutil + + +def add_description(title, description): + return f"""

    {title}{description}

    """ + + +def add_header(header_list): + head_str = "\n" + for i, h in enumerate(header_list): + if i == 0: + head_str += f"""\t {h} \n""" # + continue + head_str += f"""\t{h}\n""" + head_str += "\n" + return head_str + + +def add_row(text, wav_list): + row_str = f"""{text}\n""" + for wav_path in wav_list: + row_str += f"""\t""" + \ + f""" \n""" + row_str += "\n" + return row_str + + +def read_configs(config): + content = [] + for line in open(config): + ln = line.strip().split("|") + content.append([v.strip() for v in ln]) + return content + + +def make_table(config_text=None, config_list=None, max_rows=20, data_dir=None): + if data_dir is None: + input_dir = os.path.dirname(config_text) + else: + input_dir = data_dir + + if config_list is None: + content = read_configs(config_text) + else: + content = config_list + + data = os.path.join(input_dir, 'data') + if not os.path.exists(data): + os.makedirs(data, exist_ok=True) + + table_str = "" + table_str += add_header(content[0]) + table_str += "" + row_number = 1 + for row_line in content[1:]: + nw = [] + for w in row_line[1:]: + n = os.path.basename(w) + nn = os.path.join(data, n) + shutil.copyfile(w, nn) + n_ = os.path.join('data', n) + nw.append(n_) + + table_str += add_row(row_line[0], nw) + + row_number += 1 + if row_number > max_rows: + break + + table_str += "" + table_str += "
    " + return table_str + + +def wav_to_html(file_path): + try: + with open(file_path, 'rb') as bf: + bf_data = bf.read() + base64_data = base64.b64encode(bf_data) # get base64 string + base64_message = base64_data.decode("utf-8") + return "data:audio/wav;base64, {}".format(base64_message) + except Exception as e: + return '' + + +def add_embed_row(text, wav_list, wavesurfer=True, row_number=1): + row_str = f"""{text}\n""" + i = 1 + for wav_path in wav_list: + wav_html = wav_to_html(wav_path) + if wav_html is not None: + if not wavesurfer: + row_str += f"""\t""" + \ + f""" \n""" + else: + indx = f"{row_number}_{i}" + row_str += f"""\t""" + \ + f"""{wavesurfer_cell(wav_path, indx)}""" + \ + f""" \n""" + + else: + row_str += f"""\t \n""" + i += 1 + row_str += "\n" + return row_str + + +def wavesurfer_cell(fpath, i=1): + wstr = wav_to_html(fpath) + wstr = f""" +
    + + + """ + return wstr + + +def make_embed_table(config_text=None, config_list=None, wavesurfer=False, max_rows=100): + if config_list is None: + content = content = read_configs(config_text) + else: + content = config_list + + table_str = "" + table_str += add_header(content[0]) + table_str += "" + row_number = 1 + for row_line in content[1:]: + table_str += add_embed_row(row_line[0], row_line[1:], wavesurfer=wavesurfer, row_number=row_number) + row_number += 1 + if row_number > max_rows: + break + + table_str += "" + table_str += "
    " + return table_str + + +def make_html(args): + if not args.split: + tbStr = make_embed_table(config_text=args.config, wavesurfer=args.wave_surf, max_rows=args.max_rows) + else: + tbStr = make_table(config_text=args.config, max_rows=args.max_rows, data_dir=os.path.dirname(args.name)) + + template = """\ + + + Audio Samples + + + + + + """ + template += """ + + +

    Audio Samples

    +
    +

    {}

    +
    + + + """.format(tbStr) + + with open(args.name, "w") as g: + g.write(template) + + +if __name__ == "__main__": + """ + args: + embd + html_name + data_data + """ + parser = argparse.ArgumentParser() + parser.add_argument('-n', '--name', type=str, default='eval.html') + parser.add_argument('-s', '--split', action='store_true') + parser.add_argument('-c', '--config', type=str, required=True) + parser.add_argument('--wave-surf', action='store_true') + parser.add_argument('--max_rows', type=int, default=100) + args = parser.parse_args() + make_html(args) diff --git a/MMaDA/models/speech_tokenization/UVITS/script/html_index.py b/MMaDA/models/speech_tokenization/UVITS/script/html_index.py new file mode 100644 index 0000000000000000000000000000000000000000..42b899dc1f9d5c0110f0a48745fb3d53b99d61de --- /dev/null +++ b/MMaDA/models/speech_tokenization/UVITS/script/html_index.py @@ -0,0 +1,105 @@ +from dataclasses import dataclass +import os +from argparse import ArgumentParser + + +def main(conf): + texts = [v.strip().split('|') for v in open(conf.text_file)] + if conf.n_speakers > 0: + texts = {k: v for k, _, v in texts} + else: + texts = {k: v for k, v in texts} + + with open(conf.input_file) as f, open(conf.index_file, 'w') as g: + g.write("Text| VITS-Ref| Ground Truth\n") + for idx, li in enumerate(f): + n, *_ = li.strip().split('|') + text = texts[n] + w1, w2 = os.path.join(conf.audio_dir, f'{idx}_pred.wav'), \ + os.path.join(conf.audio_dir, f'{idx}_gt.wav') + g.write(f'{text}|{w1}|{w2}\n') + + +def vc_main(conf): + texts = [v.strip().split('|') for v in open(conf.text_file)] + texts = {k: v for k, _, v in texts} + + with open(conf.input_file) as f, open(conf.index_file, 'w') as g: + g.write("Text| Source | Conversion(F) | Target(F) | Conversion(M) | Target(M) \n") + for idx, li in enumerate(f): + n, *_ = li.strip().split('|') + text = texts[n] + sr, vc1, tar1, vc2, tar2 = os.path.join(conf.audio_dir, f'{idx}_src.wav'), \ + os.path.join(conf.audio_dir, f'{idx}_vc0.wav'), \ + os.path.join(conf.audio_dir, f'{idx}_tar0.wav'), \ + os.path.join(conf.audio_dir, f'{idx}_vc1.wav'), \ + os.path.join(conf.audio_dir, f'{idx}_tar1.wav') + g.write(f'{text}|{sr}|{vc1}|{tar1}|{vc2}|{tar2}\n') + + +if __name__ == '__main__': + + @dataclass + class Conf: + text_file: str = 'filelists/vctk_audio_sid_text_test_filelist.txt' + input_file: str = 'filelists/vctk_audio_sid_text_test_filelist.txt.unit.reduced' + index_file: str = 'result/index.txt' + audio_dir: str = 'result/spiral-20ms/base' + + + # conf = Conf() + + parser = ArgumentParser() + parser.add_argument('--text_file', default=None) + parser.add_argument('--input_file', default=None) + parser.add_argument('--vc_file', default=None) + parser.add_argument('--index_file', type=str, required=True) + parser.add_argument('--audio_dir', type=str, required=True) + parser.add_argument('--n_speakers', type=int, default=109) + args = parser.parse_args() + + if args.vc_file is None: + main(args) + else: + vc_main(args) + + + @dataclass + class VCConf: + text_file: str = 'filelists/vctk_audio_sid_text_test_filelist.txt' + input_file: str = 'filelists/vctk_audio_sid_text_test_filelist.txt.unit.reduced' + vc_file: str = 'filelists/vctk_vc_pairs.txt' + index_file: str = 'result/index-vc.txt' + audio_dir: str = 'result/spiral-20ms/base-vc' + + + @dataclass + class CmuJVCConf: + text_file: str = 'filelists/cmu_arctic_audio_sid_text_test_filelist.txt' + input_file: str = 'filelists/cmu_arctic_audio_sid_text_test_filelist.unit.reduced' + vc_file: str = 'filelists/cmu_vctk_vc_pairs.txt' + index_file: str = 'result/index-vc-a2m.txt' + audio_dir: str = 'result/spiral-20ms/base-vc-a2m' + + + @dataclass + class CmuVCConf: + text_file: str = 'filelists/cmu_arctic_audio_sid_text_test_filelist.txt' + input_file: str = 'filelists/cmu_arctic_audio_sid_text_test_filelist.unit.reduced' + vc_file: str = 'filelists/cmu_vc_pairs.txt' + index_file: str = 'result/index-vc-a2a.txt' + audio_dir: str = 'result/spiral-20ms/base-vc-a2a' + + + @dataclass + class CmuVCConfSC: + text_file: str = 'filelists/cmu_arctic_audio_sid_text_test_filelist.txt' + input_file: str = 'filelists/cmu_arctic_audio_sid_text_test_filelist.unit.reduced' + vc_file: str = 'filelists/cmu_vc_pairs.txt' + index_file: str = 'result/index-sc-vc-a2a.txt' + audio_dir: str = 'result/spiral-20ms/sc-vc-a2a' + + # conf = VCConf() + # conf = CmuJVCConf() + # conf = CmuVCConf() + # conf = CmuVCConfSC() diff --git a/MMaDA/models/speech_tokenization/UVITS/script/vae_utils.py b/MMaDA/models/speech_tokenization/UVITS/script/vae_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ae07fd912977c6a6aaa9dd22ad427d64180b0749 --- /dev/null +++ b/MMaDA/models/speech_tokenization/UVITS/script/vae_utils.py @@ -0,0 +1,225 @@ +""" +Author:zhengnianzu +From repo: https://github.com/divymurli/VAEs +Time: 2022.10.31 +Place: Shenzhen +""" + +import torch +import numpy as np +import torch.nn.functional as F + + +def sample_gaussian(m, v): + """ + Element-wise application reparameterization trick to sample from Gaussian + + Args: + m: tensor: (batch, ...): Mean + v: tensor: (batch, ...): Variance + + Return: + z: tensor: (batch, ...): Samples + """ + ################################################################################ + # TODO: Modify/complete the code here + # Sample z + ################################################################################ + + ################################################################################ + # End of code modification + ################################################################################ + sample = torch.randn(m.shape).to(v.device) + + z = m + (v ** 0.5) * sample + return z + + +def log_normal(x, m, v): + """ + Computes the elem-wise log probability of a Gaussian and then sum over the + last dim. Basically we're assuming all dims are batch dims except for the + last dim. + + Args: + x: tensor: (batch, ..., dim): Observation + m: tensor: (batch, ..., dim): Mean + v: tensor: (batch, ..., dim): Variance + + Return: + kl: tensor: (batch1, batch2, ...): log probability of each sample. Note + that the summation dimension (dim=-1) is not kept + """ + ################################################################################ + # TODO: Modify/complete the code here + # Compute element-wise log probability of normal and remember to sum over + # the last dimension + ################################################################################ + # print("q_m", m.size()) + # print("q_v", v.size()) + const = -0.5 * x.size(-1) * torch.log(2 * torch.tensor(np.pi)) + # print(const.size()) + log_det = -0.5 * torch.sum(torch.log(v), dim=-1) + # print("log_det", log_det.size()) + log_exp = -0.5 * torch.sum((x - m) ** 2 / v, dim=-1) + + log_prob = const + log_det + log_exp + + ################################################################################ + # End of code modification + ################################################################################ + return log_prob + + +def log_normal_mixture(z, m, v, w=None): + """ + Computes log probability of a uniformly-weighted Gaussian mixture. + + Args: + z: tensor: (batch, dim): Observations + m: tensor: (batch, mix, dim): Mixture means + v: tensor: (batch, mix, dim): Mixture variances + + Return: + log_prob: tensor: (batch,): log probability of each sample + """ + ################################################################################ + # TODO: Modify/complete the code here + # Compute the uniformly-weighted mixture of Gaussians density for each sample + # in the batch + ################################################################################ + z = z.unsqueeze(1) + log_probs = log_normal(z, m, v) + # print("log_probs_mix", log_probs.shape) + + if w is not None: + log_prob = log_weighted_sum_exp(log_probs, w, 1) + else: + log_prob = log_mean_exp(log_probs, 1) + + # print("log_prob_mix", log_prob.size()) + + ################################################################################ + # End of code modification + ################################################################################ + return log_prob + + +def gaussian_parameters(h, dim=-1): + """ + Converts generic real-valued representations into mean and variance + parameters of a Gaussian distribution + + Args: + h: tensor: (batch, ..., dim, ...): Arbitrary tensor + dim: int: (): Dimension along which to split the tensor for mean and + variance + + Returns:z + m: tensor: (batch, ..., dim / 2, ...): Mean + v: tensor: (batch, ..., dim / 2, ...): Variance + """ + m, h = torch.split(h, h.size(dim) // 2, dim=dim) + v = F.softplus(h) + 1e-8 + return m, v + + +def kl_normal(qm, qv, pm, pv): + """ + Computes the elem-wise KL divergence between two normal distributions KL(q || p) and + sum over the last dimension + + Args: + qm: tensor: (batch, dim): q mean + qv: tensor: (batch, dim): q variance + pm: tensor: (batch, dim): p mean + pv: tensor: (batch, dim): p variance + + Return: + kl: tensor: (batch,): kl between each sample + """ + element_wise = 0.5 * (torch.log(pv) - torch.log(qv) + qv / pv + (qm - pm).pow(2) / pv - 1) + kl = element_wise.sum(-1) + # print("log var1", qv) + return kl + + +def log_mean_exp(x, dim): + """ + Compute the log(mean(exp(x), dim)) in a numerically stable manner + + Args: + x: tensor: (...): Arbitrary tensor + dim: int: (): Dimension along which mean is computed + + Return: + _: tensor: (...): log(mean(exp(x), dim)) + """ + return log_sum_exp(x, dim) - np.log(x.size(dim)) + + +def log_weighted_sum_exp(x, w, dim=-1): + """ + compute the log(weighted sum(exp(x), dim)) + """ + max_x = torch.max(x, dim)[0] + new_x = x - max_x.unsqueeze(dim).expand_as(x) + return max_x + (new_x.exp().mul(w).sum(dim)).log() + + +def log_sum_exp(x, dim=0): + """ + Compute the log(sum(exp(x), dim)) in a numerically stable manner + + Args: + x: tensor: (...): Arbitrary tensor + dim: int: (): Dimension along which sum is computed + + Return: + _: tensor: (...): log(sum(exp(x), dim)) + """ + max_x = torch.max(x, dim)[0] + new_x = x - max_x.unsqueeze(dim).expand_as(x) + return max_x + (new_x.exp().sum(dim)).log() + + +def gaussian_mixture_parameters(buffer: torch.Tensor): + """Speaker prior. + Args: + buffer: [torch.float32; [K, E x 2 + 1]], distribution weights. + Returns: + weight: [torch.float32; [K]], weights of each modals. + mean: [torch.float32; [K, E]], mean vectors. + std: [torch.float32; [K, E]], standard deviations. + """ + # [K] + weight = torch.softmax(buffer[..., 0], dim=0) + # [K, E], [K, E] + mean, logstd = buffer[..., 1:].chunk(2, dim=-1) + # [K, E] + std = F.softplus(logstd) + # [K], [K, E], [K, E] + return weight, mean, std + + +class Weight_Scheduler(object): + def __init__(self, base_wt, n_warmup_steps, update_step, power): + """ + warmup and update every update_step + """ + self.base_wt = base_wt + self.n_warmup_steps = n_warmup_steps + self.power = power + self.update_step = update_step + + def _get_wt(self, n_current_steps): + if self.n_warmup_steps != 0 and n_current_steps <= self.n_warmup_steps and n_current_steps % self.update_step == 0: + scale = np.power(self.n_warmup_steps, -(1 + self.power)) * n_current_steps + elif n_current_steps > self.n_warmup_steps and n_current_steps % self.update_step == 0: + scale = np.power(n_current_steps, -self.power) + else: + scale = 0 + return scale * self.base_wt + + def _get_max_wt(self): + return np.power(self.n_warmup_steps, -self.power) diff --git a/MMaDA/models/speech_tokenization/UVITS/text/LICENSE b/MMaDA/models/speech_tokenization/UVITS/text/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..4ad4ed1d5e34d95c8380768ec16405d789cc6de4 --- /dev/null +++ b/MMaDA/models/speech_tokenization/UVITS/text/LICENSE @@ -0,0 +1,19 @@ +Copyright (c) 2017 Keith Ito + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. diff --git a/MMaDA/models/speech_tokenization/UVITS/text/__init__.py b/MMaDA/models/speech_tokenization/UVITS/text/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8127013427b84ba89ece252521237b93426a7391 --- /dev/null +++ b/MMaDA/models/speech_tokenization/UVITS/text/__init__.py @@ -0,0 +1,53 @@ +""" from https://github.com/keithito/tacotron """ +from text import cleaners +from text.symbols import symbols + +# Mappings from symbol to numeric ID and vice versa: +_symbol_to_id = {s: i for i, s in enumerate(symbols)} +_id_to_symbol = {i: s for i, s in enumerate(symbols)} + + +def text_to_sequence(text, cleaner_names): + '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text. + Args: + text: string to convert to a sequence + cleaner_names: names of the cleaner functions to run the text through + Returns: + List of integers corresponding to the symbols in the text + ''' + sequence = [] + + clean_text = _clean_text(text, cleaner_names) + for symbol in clean_text: + symbol_id = _symbol_to_id[symbol] + sequence += [symbol_id] + return sequence + + +def cleaned_text_to_sequence(cleaned_text): + '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text. + Args: + text: string to convert to a sequence + Returns: + List of integers corresponding to the symbols in the text + ''' + sequence = [_symbol_to_id[symbol] for symbol in cleaned_text] + return sequence + + +def sequence_to_text(sequence): + '''Converts a sequence of IDs back to a string''' + result = '' + for symbol_id in sequence: + s = _id_to_symbol[symbol_id] + result += s + return result + + +def _clean_text(text, cleaner_names): + for name in cleaner_names: + cleaner = getattr(cleaners, name) + if not cleaner: + raise Exception('Unknown cleaner: %s' % name) + text = cleaner(text) + return text diff --git a/MMaDA/models/speech_tokenization/UVITS/text/cleaners.py b/MMaDA/models/speech_tokenization/UVITS/text/cleaners.py new file mode 100644 index 0000000000000000000000000000000000000000..661cdb93178a1ebe6577db4810237aafddc64b64 --- /dev/null +++ b/MMaDA/models/speech_tokenization/UVITS/text/cleaners.py @@ -0,0 +1,104 @@ +""" from https://github.com/keithito/tacotron """ + +''' +Cleaners are transformations that run over the input text at both training and eval time. + +Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners" +hyperparameter. Some cleaners are English-specific. You'll typically want to use: + 1. "english_cleaners" for English text + 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using + the Unidecode library (https://pypi.python.org/pypi/Unidecode) + 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update + the symbols in symbols.py to match your data). +''' + +import re +from unidecode import unidecode +from phonemizer import phonemize + +# Regular expression matching whitespace: +_whitespace_re = re.compile(r'\s+') + +# List of (regular expression, replacement) pairs for abbreviations: +_abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [ + ('mrs', 'misess'), + ('mr', 'mister'), + ('dr', 'doctor'), + ('st', 'saint'), + ('co', 'company'), + ('jr', 'junior'), + ('maj', 'major'), + ('gen', 'general'), + ('drs', 'doctors'), + ('rev', 'reverend'), + ('lt', 'lieutenant'), + ('hon', 'honorable'), + ('sgt', 'sergeant'), + ('capt', 'captain'), + ('esq', 'esquire'), + ('ltd', 'limited'), + ('col', 'colonel'), + ('ft', 'fort'), +]] + + +def expand_abbreviations(text): + for regex, replacement in _abbreviations: + text = re.sub(regex, replacement, text) + return text + + +def expand_numbers(text): + return normalize_numbers(text) + + +def lowercase(text): + return text.lower() + + +def collapse_whitespace(text): + return re.sub(_whitespace_re, ' ', text) + + +def convert_to_ascii(text): + return unidecode(text) + + +def basic_cleaners(text): + '''Basic pipeline that lowercases and collapses whitespace without transliteration.''' + text = lowercase(text) + text = collapse_whitespace(text) + return text + + +def transliteration_cleaners(text): + '''Pipeline for non-English text that transliterates to ASCII.''' + text = convert_to_ascii(text) + text = lowercase(text) + text = collapse_whitespace(text) + return text + + +def english_cleaners(text): + '''Pipeline for English text, including abbreviation expansion.''' + text = convert_to_ascii(text) + text = lowercase(text) + text = expand_abbreviations(text) + phonemes = phonemize(text, language='en-us', backend='espeak', strip=True) + phonemes = collapse_whitespace(phonemes) + return phonemes + + +def english_cleaners2(text): + '''Pipeline for English text, including abbreviation expansion. + punctuation + stress''' + text = convert_to_ascii(text) + text = lowercase(text) + text = expand_abbreviations(text) + phonemes = phonemize(text, language='en-us', backend='espeak', strip=True, preserve_punctuation=True, + with_stress=True) + phonemes = collapse_whitespace(phonemes) + return phonemes + + +def text_split(text): + return text.split() diff --git a/MMaDA/models/speech_tokenization/UVITS/text/symbols.py b/MMaDA/models/speech_tokenization/UVITS/text/symbols.py new file mode 100644 index 0000000000000000000000000000000000000000..e6964382c61edd9eeddcccd95c9b95e69aee01c3 --- /dev/null +++ b/MMaDA/models/speech_tokenization/UVITS/text/symbols.py @@ -0,0 +1,21 @@ +""" from https://github.com/keithito/tacotron """ + +''' +Defines the set of symbols used in text input to the model. +''' +_pad = '_' +_punctuation = ';:,.!?”¿—…"Ā«Ā»ā€œā€ ' +_letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz' +_letters_ipa = "É‘ÉÉ’Ć¦É“Ź™Ī²É”É•Ć§É—É–Ć°Ź¤É™É˜ÉšÉ›ÉœÉÉžÉŸŹ„É”É É¢Ź›É¦É§Ä§É„ŹœÉØÉŖŹÉ­É¬É«É®ŹŸÉ±ÉÆÉ°Å‹É³É²É“ĆøÉµÉøĪøÅ“É¶Ź˜É¹ÉŗÉ¾É»Ź€ŹÉ½Ź‚ŹƒŹˆŹ§Ź‰ŹŠŹ‹ā±±ŹŒÉ£É¤ŹĻ‡ŹŽŹŹ‘ŹŹ’Ź”Ź”Ź•Ź¢Ē€ĒĒ‚ĒƒĖˆĖŒĖĖ‘Ź¼Ź“Ź°Ź±Ź²Ź·Ė Ė¤Ėžā†“ā†‘ā†’ā†—ā†˜'Ģ©'įµ»" + +# Export all symbols: Todo: should write more elegently!! +# symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa) + [str(v) for v in range(1024)] # 1024 by daxin +# symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa) + [str(v) for v in range(2048)] # 2048 by dehua +symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa) + [str(v) for v in range(4096)] # 4096 by daxin + +symbols_with_1024 = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa) + [str(v) for v in range(1024)] # 1024 by daxin +symbols_with_2048 = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa) + [str(v) for v in range(2048)] # 2048 by dehua +symbols_with_4096 = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa) + [str(v) for v in range(4096)] # 4096 by daxin + +# Special symbol ids +SPACE_ID = symbols.index(" ") diff --git a/MMaDA/models/speech_tokenization/UVITS/transforms.py b/MMaDA/models/speech_tokenization/UVITS/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..12dd72776a6b788932329da169155a289eb645bc --- /dev/null +++ b/MMaDA/models/speech_tokenization/UVITS/transforms.py @@ -0,0 +1,192 @@ +import torch +from torch.nn import functional as F + +import numpy as np + +DEFAULT_MIN_BIN_WIDTH = 1e-3 +DEFAULT_MIN_BIN_HEIGHT = 1e-3 +DEFAULT_MIN_DERIVATIVE = 1e-3 + + +def piecewise_rational_quadratic_transform(inputs, + unnormalized_widths, + unnormalized_heights, + unnormalized_derivatives, + inverse=False, + tails=None, + tail_bound=1., + min_bin_width=DEFAULT_MIN_BIN_WIDTH, + min_bin_height=DEFAULT_MIN_BIN_HEIGHT, + min_derivative=DEFAULT_MIN_DERIVATIVE): + if tails is None: + spline_fn = rational_quadratic_spline + spline_kwargs = {} + else: + spline_fn = unconstrained_rational_quadratic_spline + spline_kwargs = { + 'tails': tails, + 'tail_bound': tail_bound + } + + outputs, logabsdet = spline_fn( + inputs=inputs, + unnormalized_widths=unnormalized_widths, + unnormalized_heights=unnormalized_heights, + unnormalized_derivatives=unnormalized_derivatives, + inverse=inverse, + min_bin_width=min_bin_width, + min_bin_height=min_bin_height, + min_derivative=min_derivative, + **spline_kwargs + ) + return outputs, logabsdet + + +def searchsorted(bin_locations, inputs, eps=1e-6): + bin_locations[..., -1] += eps + return torch.sum( + inputs[..., None] >= bin_locations, + dim=-1 + ) - 1 + + +def unconstrained_rational_quadratic_spline(inputs, + unnormalized_widths, + unnormalized_heights, + unnormalized_derivatives, + inverse=False, + tails='linear', + tail_bound=1., + min_bin_width=DEFAULT_MIN_BIN_WIDTH, + min_bin_height=DEFAULT_MIN_BIN_HEIGHT, + min_derivative=DEFAULT_MIN_DERIVATIVE): + inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound) + outside_interval_mask = ~inside_interval_mask + + outputs = torch.zeros_like(inputs) + logabsdet = torch.zeros_like(inputs) + + if tails == 'linear': + unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1)) + constant = np.log(np.exp(1 - min_derivative) - 1) + unnormalized_derivatives[..., 0] = constant + unnormalized_derivatives[..., -1] = constant + + outputs[outside_interval_mask] = inputs[outside_interval_mask] + logabsdet[outside_interval_mask] = 0 + else: + raise RuntimeError('{} tails are not implemented.'.format(tails)) + + outputs[inside_interval_mask], logabsdet[inside_interval_mask] = rational_quadratic_spline( + inputs=inputs[inside_interval_mask], + unnormalized_widths=unnormalized_widths[inside_interval_mask, :], + unnormalized_heights=unnormalized_heights[inside_interval_mask, :], + unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :], + inverse=inverse, + left=-tail_bound, right=tail_bound, bottom=-tail_bound, top=tail_bound, + min_bin_width=min_bin_width, + min_bin_height=min_bin_height, + min_derivative=min_derivative + ) + + return outputs, logabsdet + + +def rational_quadratic_spline(inputs, + unnormalized_widths, + unnormalized_heights, + unnormalized_derivatives, + inverse=False, + left=0., right=1., bottom=0., top=1., + min_bin_width=DEFAULT_MIN_BIN_WIDTH, + min_bin_height=DEFAULT_MIN_BIN_HEIGHT, + min_derivative=DEFAULT_MIN_DERIVATIVE): + if torch.min(inputs) < left or torch.max(inputs) > right: + raise ValueError('Input to a transform is not within its domain') + + num_bins = unnormalized_widths.shape[-1] + + if min_bin_width * num_bins > 1.0: + raise ValueError('Minimal bin width too large for the number of bins') + if min_bin_height * num_bins > 1.0: + raise ValueError('Minimal bin height too large for the number of bins') + + widths = F.softmax(unnormalized_widths, dim=-1) + widths = min_bin_width + (1 - min_bin_width * num_bins) * widths + cumwidths = torch.cumsum(widths, dim=-1) + cumwidths = F.pad(cumwidths, pad=(1, 0), mode='constant', value=0.0) + cumwidths = (right - left) * cumwidths + left + cumwidths[..., 0] = left + cumwidths[..., -1] = right + widths = cumwidths[..., 1:] - cumwidths[..., :-1] + + derivatives = min_derivative + F.softplus(unnormalized_derivatives) + + heights = F.softmax(unnormalized_heights, dim=-1) + heights = min_bin_height + (1 - min_bin_height * num_bins) * heights + cumheights = torch.cumsum(heights, dim=-1) + cumheights = F.pad(cumheights, pad=(1, 0), mode='constant', value=0.0) + cumheights = (top - bottom) * cumheights + bottom + cumheights[..., 0] = bottom + cumheights[..., -1] = top + heights = cumheights[..., 1:] - cumheights[..., :-1] + + if inverse: + bin_idx = searchsorted(cumheights, inputs)[..., None] + else: + bin_idx = searchsorted(cumwidths, inputs)[..., None] + + input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0] + input_bin_widths = widths.gather(-1, bin_idx)[..., 0] + + input_cumheights = cumheights.gather(-1, bin_idx)[..., 0] + delta = heights / widths + input_delta = delta.gather(-1, bin_idx)[..., 0] + + input_derivatives = derivatives.gather(-1, bin_idx)[..., 0] + input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0] + + input_heights = heights.gather(-1, bin_idx)[..., 0] + + if inverse: + a = (((inputs - input_cumheights) * (input_derivatives + + input_derivatives_plus_one + - 2 * input_delta) + + input_heights * (input_delta - input_derivatives))) + b = (input_heights * input_derivatives + - (inputs - input_cumheights) * (input_derivatives + + input_derivatives_plus_one + - 2 * input_delta)) + c = - input_delta * (inputs - input_cumheights) + + discriminant = b.pow(2) - 4 * a * c + assert (discriminant >= 0).all() + + root = (2 * c) / (-b - torch.sqrt(discriminant)) + outputs = root * input_bin_widths + input_cumwidths + + theta_one_minus_theta = root * (1 - root) + denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta) + * theta_one_minus_theta) + derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * root.pow(2) + + 2 * input_delta * theta_one_minus_theta + + input_derivatives * (1 - root).pow(2)) + logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) + + return outputs, -logabsdet + else: + theta = (inputs - input_cumwidths) / input_bin_widths + theta_one_minus_theta = theta * (1 - theta) + + numerator = input_heights * (input_delta * theta.pow(2) + + input_derivatives * theta_one_minus_theta) + denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta) + * theta_one_minus_theta) + outputs = input_cumheights + numerator / denominator + + derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * theta.pow(2) + + 2 * input_delta * theta_one_minus_theta + + input_derivatives * (1 - theta).pow(2)) + logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) + + return outputs, logabsdet diff --git a/MMaDA/models/speech_tokenization/UVITS/utils.py b/MMaDA/models/speech_tokenization/UVITS/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a38a7befb5f089fa66e81b67f3945eda838435f5 --- /dev/null +++ b/MMaDA/models/speech_tokenization/UVITS/utils.py @@ -0,0 +1,287 @@ +import os +import glob +import sys +import argparse +import logging +import json +import subprocess +import numpy as np +from scipy.io.wavfile import read +import torch + +MATPLOTLIB_FLAG = False + +logging.basicConfig(stream=sys.stdout, level=logging.INFO) +logger = logging + + +def load_checkpoint(checkpoint_path, model, optimizer=None): + assert os.path.isfile(checkpoint_path) + checkpoint_dict = torch.load(checkpoint_path, map_location='cpu') + iteration = checkpoint_dict['iteration'] + learning_rate = checkpoint_dict['learning_rate'] + if optimizer is not None: + optimizer.load_state_dict(checkpoint_dict['optimizer']) + saved_state_dict = checkpoint_dict['model'] + if hasattr(model, 'module'): + state_dict = model.module.state_dict() + else: + state_dict = model.state_dict() + new_state_dict = {} + for k, v in state_dict.items(): + try: + new_state_dict[k] = saved_state_dict[k] if 'GAMMA' not in k and 'BETA' not in k else saved_state_dict[k.lower()] + except: + logger.info("%s is not in the checkpoint" % k) + new_state_dict[k] = v + if hasattr(model, 'module'): + model.module.load_state_dict(new_state_dict) + else: + model.load_state_dict(new_state_dict) + logger.info("Loaded checkpoint '{}' (iteration {})".format( + checkpoint_path, iteration)) + return model, optimizer, learning_rate, iteration + + +def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path): + logger.info("Saving model and optimizer state at iteration {} to {}".format( + iteration, checkpoint_path)) + if hasattr(model, 'module'): + state_dict = model.module.state_dict() + else: + state_dict = model.state_dict() + torch.save({'model': state_dict, + 'iteration': iteration, + 'optimizer': optimizer.state_dict(), + 'learning_rate': learning_rate}, checkpoint_path) + + +def summarize(writer, global_step, scalars={}, histograms={}, images={}, audios={}, audio_sampling_rate=22050): + for k, v in scalars.items(): + writer.add_scalar(k, v, global_step) + for k, v in histograms.items(): + writer.add_histogram(k, v, global_step) + for k, v in images.items(): + writer.add_image(k, v, global_step, dataformats='HWC') + for k, v in audios.items(): + writer.add_audio(k, v, global_step, audio_sampling_rate) + + +def latest_checkpoint_path(dir_path, regex="G_*.pth"): + f_list = glob.glob(os.path.join(dir_path, regex)) + f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f)))) + x = f_list[-1] + print(x) + return x + + +def plot_spectrogram_to_numpy(spectrogram): + global MATPLOTLIB_FLAG + if not MATPLOTLIB_FLAG: + import matplotlib + matplotlib.use("Agg") + MATPLOTLIB_FLAG = True + mpl_logger = logging.getLogger('matplotlib') + mpl_logger.setLevel(logging.WARNING) + import matplotlib.pylab as plt + import numpy as np + + fig, ax = plt.subplots(figsize=(10, 2)) + im = ax.imshow(spectrogram, aspect="auto", origin="lower", + interpolation='none') + plt.colorbar(im, ax=ax) + plt.xlabel("Frames") + plt.ylabel("Channels") + plt.tight_layout() + + fig.canvas.draw() + data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') + data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) + plt.close() + return data + + +def plot_alignment_to_numpy(alignment, info=None): + global MATPLOTLIB_FLAG + if not MATPLOTLIB_FLAG: + import matplotlib + matplotlib.use("Agg") + MATPLOTLIB_FLAG = True + mpl_logger = logging.getLogger('matplotlib') + mpl_logger.setLevel(logging.WARNING) + import matplotlib.pylab as plt + import numpy as np + + fig, ax = plt.subplots(figsize=(6, 4)) + im = ax.imshow(alignment.transpose(), aspect='auto', origin='lower', + interpolation='none') + fig.colorbar(im, ax=ax) + xlabel = 'Decoder timestep' + if info is not None: + xlabel += '\n\n' + info + plt.xlabel(xlabel) + plt.ylabel('Encoder timestep') + plt.tight_layout() + + fig.canvas.draw() + data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') + data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) + plt.close() + return data + + +def load_wav_to_torch(full_path): + sampling_rate, data = read(full_path) + return torch.FloatTensor(data.astype(np.float32)), sampling_rate + + +def load_filepaths_and_text_(filename, split="|"): + with open(filename, encoding='utf-8') as f: + filepaths_and_text = [line.strip().split(split) for line in f] + return filepaths_and_text + + +def load_filepaths_and_text(filenames, hparams=None, split='|', is_training=True): + if isinstance(filenames, list): + each_speakers = hparams.get("each_speakers", None) + if is_training: + assert len(each_speakers) == len(filenames) + filepaths_and_text = [] + spk_count = 0 + for i, filename in enumerate(filenames): + temp = load_filepaths_and_text_(filename, split=split) + + if each_speakers is not None: + # unified speaker coding problem + temp = [(li[0], str(int(li[1]) + spk_count), li[2]) for li in temp] + spk_count += each_speakers[i] + + filepaths_and_text.extend(temp) + + return filepaths_and_text + + else: + return load_filepaths_and_text_(filenames, split=split) + + +def get_hparams(init=True, argv=None): + parser = argparse.ArgumentParser() + parser.add_argument('-c', '--config', type=str, default="./configs/base.json", + help='JSON file for configuration') + parser.add_argument('-m', '--model', type=str, required=True, + help='Model name') + + args = parser.parse_args(args=argv) + model_dir = args.model + + if not os.path.exists(model_dir): + os.makedirs(model_dir) + + config_path = args.config + config_save_path = os.path.join(model_dir, "config.json") + if init: + with open(config_path, "r") as f: + data = f.read() + with open(config_save_path, "w") as f: + f.write(data) + else: + with open(config_save_path, "r") as f: + data = f.read() + config = json.loads(data) + + hparams = HParams(**config) + hparams.model_dir = model_dir + return hparams + + +def get_hparams_from_dir(model_dir): + config_save_path = os.path.join(model_dir, "config.json") + with open(config_save_path, "r") as f: + data = f.read() + config = json.loads(data) + + hparams = HParams(**config) + hparams.model_dir = model_dir + return hparams + + +def get_hparams_from_file(config_path): + with open(config_path, "r") as f: + data = f.read() + config = json.loads(data) + + hparams = HParams(**config) + return hparams + + +def check_git_hash(model_dir): + source_dir = os.path.dirname(os.path.realpath(__file__)) + if not os.path.exists(os.path.join(source_dir, ".git")): + logger.warn("{} is not a git repository, therefore hash value comparison will be ignored.".format( + source_dir + )) + return + + cur_hash = subprocess.getoutput("git rev-parse HEAD") + + path = os.path.join(model_dir, "githash") + if os.path.exists(path): + saved_hash = open(path).read() + if saved_hash != cur_hash: + logger.warn("git hash values are different. {}(saved) != {}(current)".format( + saved_hash[:8], cur_hash[:8])) + else: + open(path, "w").write(cur_hash) + + +def get_logger(model_dir, filename="train.log"): + global logger + logger = logging.getLogger(os.path.basename(model_dir)) + logger.setLevel(logging.INFO) + + formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s") + if not os.path.exists(model_dir): + os.makedirs(model_dir) + h = logging.FileHandler(os.path.join(model_dir, filename)) + h.setLevel(logging.DEBUG) + h.setFormatter(formatter) + logger.addHandler(h) + return logger + + +class HParams(): + def __init__(self, **kwargs): + for k, v in kwargs.items(): + if type(v) == dict: + v = HParams(**v) + self[k] = v + + def keys(self): + return self.__dict__.keys() + + def items(self): + return self.__dict__.items() + + def values(self): + return self.__dict__.values() + + def __len__(self): + return len(self.__dict__) + + def __getitem__(self, key): + return getattr(self, key) + + def __setitem__(self, key, value): + return setattr(self, key, value) + + def __contains__(self, key): + return key in self.__dict__ + + def __repr__(self): + return self.__dict__.__repr__() + + def get(self, key, default_value): + if key not in self.keys(): + return default_value + else: + return self[key] diff --git a/MMaDA/models/speech_utils.py b/MMaDA/models/speech_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6aafe7bfd18f7e9d2e3ef588140f45b0b03a062f --- /dev/null +++ b/MMaDA/models/speech_utils.py @@ -0,0 +1,166 @@ +import os +import numpy as np +from scipy.io.wavfile import write +import torch +import random +from importlib import import_module +from omegaconf import OmegaConf + +import sys +sys.path.append(os.path.join(os.path.dirname(__file__), "./speech_tokenization/UVITS")) +import utils +from models.speech_tokenization.UVITS.models import SynthesizerTrn +from text import text_to_sequence +from my_synthesis.my_synthesis_for_speech_unit_sequence_recombination import get_U2S_config_checkpoint_file +sys.path.append(os.path.join(os.path.dirname(__file__), "./speech_tokenization/SPIRAL_L2_BN_FSQ_CTC")) +from my_extract_unit_for_speech.extract_unit_construct_wav_unit_text import \ + get_S2U_ckpt_config_path, sample_extract_unit, batch_extract_unit +from nemo.collections.asr.models.spec2vec.vq_ctc_finetune import VQCTCFinetuneModel +from nemo.utils import logging + +################# +# S2U +################# +def load_config(config=None): + if config is not None: + print("Config: ", config) + cfg_module = import_module(config.replace('/', '.')) + cfg = OmegaConf.structured(cfg_module.cfg) + OmegaConf.set_struct(cfg, True) + return cfg + +def load_S2U_model(ckpt_path, config_path, model_name): + assert model_name in ['SPIRAL-FSQ-CTC'] + cfg = load_config(config=config_path) + cfg.model.pretrain_chkpt_path = None + model = VQCTCFinetuneModel(cfg=cfg.model, trainer=None).eval() + model = model.to(dtype=torch.float32) + checkpoint = torch.load(os.path.join(os.path.dirname(__file__), ckpt_path), map_location='cpu') + missing_keys, unexpected_keys = model.load_state_dict(checkpoint['state_dict'], strict=False) + if(missing_keys): + logging.warning('Missing Keys: {}'.format(missing_keys)) + if(unexpected_keys): + logging.warning('Unexpected Keys: {}'.format(unexpected_keys)) + return model + +def s2u_extract_unit_demo(model, wav_path, model_name, reduced=True): + assert model_name in ['SPIRAL-FSQ-CTC'] + + wav_file_list = [wav_path] + wav_file_list_len = 1 + extracted_wav_file_list, skipped_wav_file_list, unreduced_unit_sequence_list, reduced_unit_sequence_list = batch_extract_unit(wav_file_list, model, max_chunk=960000) + target_unit_sequence_list = reduced_unit_sequence_list if reduced else unreduced_unit_sequence_list + if len(extracted_wav_file_list) != 0: + target_unit_sequence = target_unit_sequence_list[0] + else: + wav_file = skipped_wav_file_list[0] + unreduced_unit_sequence, reduced_unit_sequence = sample_extract_unit(wav_file, model) + target_unit_sequence = reduced_unit_sequence if reduced else unreduced_unit_sequence + + return "".join(["<|speech_{}|>".format(each) for each in target_unit_sequence.split(" ")]) + +################# +# U2S +################# +def load_condition_centroid(condition2style_centroid_file): + with open(os.path.join(os.path.dirname(__file__), condition2style_centroid_file), 'r') as f: + line_list = [line.replace('\n', '') for line in f] + assert line_list[0] == 'condition|style_centroid_file' + condition2style_centroid_file_dict, condition2style_centroid_embedding_dict = {}, {} + for line in line_list[1:]: + condition, style_centroid_file = line.split('|') + condition2style_centroid_file_dict[condition] = style_centroid_file + style_centroid_embedding = np.load(os.path.join(os.path.dirname(__file__), style_centroid_file)) + style_centroid_embedding = torch.FloatTensor(style_centroid_embedding).unsqueeze(1).unsqueeze(0) + condition2style_centroid_embedding_dict[condition] = style_centroid_embedding + return condition2style_centroid_file_dict, condition2style_centroid_embedding_dict + +def load_U2S_config(model_config_file): + hps = utils.get_hparams_from_file(os.path.join(os.path.dirname(__file__), model_config_file)) + from text.symbols import symbols_with_4096 as symbols + hps.num_symbols = len(symbols) + return hps + +def load_U2S_model(model_config_file, model_checkpoint_file, unit_type, ): + # load model + hps = utils.get_hparams_from_file(os.path.join(os.path.dirname(__file__), model_config_file)) + from text.symbols import symbols_with_4096 as symbols + net_g = SynthesizerTrn( + len(symbols), + hps.data.filter_length // 2 + 1, + hps.train.segment_size // hps.data.hop_length, + n_speakers=hps.data.n_speakers, + **hps.model) + net_g.eval() + utils.load_checkpoint(os.path.join(os.path.dirname(__file__), model_checkpoint_file), net_g, None) + return net_g, hps + +def synthesis(unit_sequence, style_embedding, hps, net_g, output_wav_file='output.wav'): + # synthesize speech + device = next(net_g.parameters()).device # we assume speech tokenizer is stored in a single device + logging.info("Generating audios on {}".format(device)) + with torch.no_grad(): + unit_sequence = text_to_sequence(unit_sequence, hps.data.text_cleaners) + unit_sequence = torch.LongTensor(unit_sequence) + unit_sequence = unit_sequence.unsqueeze(0).to(device) + unit_lengths = torch.LongTensor([unit_sequence.size(1)]).to(device) + if style_embedding is not None: + style_embedding = style_embedding.to(device) + audio = net_g.synthesis_from_content_unit_style_embedding( + unit_sequence, unit_lengths, style_embedding, + noise_scale=.667, noise_scale_w=0.8, length_scale=1)[0][0, 0].data.cpu().float().numpy() + write(output_wav_file, hps.data.sampling_rate, audio) + print(f'synthesized sample is saved as {output_wav_file}') + return audio + +if __name__ == "__main__": + ################# + # NPU + ################# + try: + import torch_npu + from torch_npu.npu import amp + from torch_npu.contrib import transfer_to_npu + print('Successful import torch_npu') + except Exception as e: + print(e) + + ############ + # S2U + ############ + reduced = True + reduced_mark = 'reduced' if reduced else 'unreduced' + unit_type = '40ms_multilingual_8888' + S2U_model_name = 'SPIRAL-FSQ-CTC' + + S2U_ckpt_path, S2U_config_path = get_S2U_ckpt_config_path(unit_type) + S2U_model = load_S2U_model(S2U_ckpt_path, S2U_config_path, S2U_model_name) + S2U_model = S2U_model.cuda() + + wav_file = "./examples/s2u/example.wav" + speech_unit = s2u_extract_unit_demo(S2U_model, wav_file, model_name=S2U_model_name, reduced=reduced) + print(speech_unit) + + ############ + # U2S + ############ + condition2style_centroid_file = "./speech_tokenization/condition_style_centroid/condition2style_centroid.txt" + condition2style_centroid_file_dict, condition2style_centroid_embedding_dict = load_condition_centroid(condition2style_centroid_file) + + unit_type = '40ms_multilingual_8888_xujing_cosyvoice_FT' + U2S_config_file, U2S_checkpoint_file = get_U2S_config_checkpoint_file(unit_type) + net_g, hps = load_U2S_model(U2S_config_file, U2S_checkpoint_file, unit_type) + net_g = net_g.cuda() + + content_unit = speech_unit.replace('<|speech_', '').replace('|>', ' ').strip() + emotion = random.choice(['angry', 'disgusted', 'fearful', 'happy', 'neutral', 'sad', 'surprised']) + speed = random.choice(['normal', 'fast', 'slow']) + pitch = random.choice(['normal', 'high', 'low']) + gender = random.choice(['female', 'male']) + condition = f'gender-{gender}_emotion-{emotion}_speed-{speed}_pitch-{pitch}' + + style_centroid_file = condition2style_centroid_file_dict[condition] + style_centroid_embedding = condition2style_centroid_embedding_dict[condition] + + output_wav_file = f'./examples/u2s/{condition}_output.wav' + synthesis(content_unit, style_centroid_embedding, hps, net_g, output_wav_file) \ No newline at end of file diff --git a/MMaDA/models/training_utils.py b/MMaDA/models/training_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..954cf127d05d63b87350ea1d4f0bb3461a3cfa9e --- /dev/null +++ b/MMaDA/models/training_utils.py @@ -0,0 +1,455 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import os +import random +from typing import Any, Dict, Iterable, Optional, Union + +import numpy as np +import pandas as pd +import torch +import torch.nn.functional as F + + +def enable_full_determinism(seed: int): + """ + Helper function for reproducible behavior during distributed training. See + - https://pytorch.org/docs/stable/notes/randomness.html for pytorch + """ + # set seed first + set_seed(seed) + + # Enable PyTorch deterministic mode. This potentially requires either the environment + # variable 'CUDA_LAUNCH_BLOCKING' or 'CUBLAS_WORKSPACE_CONFIG' to be set, + # depending on the CUDA version, so we set them both here + os.environ["CUDA_LAUNCH_BLOCKING"] = "1" + os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8" + torch.use_deterministic_algorithms(True) + + # Enable CUDNN deterministic mode + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + +def set_seed(seed: int): + """ + Args: + Helper function for reproducible behavior to set the seed in `random`, `numpy`, `torch`. + seed (`int`): The seed to set. + """ + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + # ^^ safe to call this function even if cuda is not available + + +# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14 +class EMA: + """ + Exponential Moving Average of models weights + """ + + def __init__( + self, + parameters: Iterable[torch.nn.Parameter], + decay: float = 0.9999, + min_decay: float = 0.0, + update_after_step: int = 0, + use_ema_warmup: bool = False, + inv_gamma: Union[float, int] = 1.0, + power: Union[float, int] = 2 / 3, + model_cls: Optional[Any] = None, + model_config: Dict[str, Any] = None, + **kwargs, + ): + """ + Args: + parameters (Iterable[torch.nn.Parameter]): The parameters to track. + decay (float): The decay factor for the exponential moving average. + min_decay (float): The minimum decay factor for the exponential moving average. + update_after_step (int): The number of steps to wait before starting to update the EMA weights. + use_ema_warmup (bool): Whether to use EMA warmup. + inv_gamma (float): + Inverse multiplicative factor of EMA warmup. Default: 1. Only used if `use_ema_warmup` is True. + power (float): Exponential factor of EMA warmup. Default: 2/3. Only used if `use_ema_warmup` is True. + device (Optional[Union[str, torch.device]]): The device to store the EMA weights on. If None, the EMA + weights will be stored on CPU. + + @crowsonkb's notes on EMA Warmup: + If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan + to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps), + gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 + at 215.4k steps). + """ + + parameters = list(parameters) + self.shadow_params = [p.clone().detach() for p in parameters] + + self.temp_stored_params = None + + self.decay = decay + self.min_decay = min_decay + self.update_after_step = update_after_step + self.use_ema_warmup = use_ema_warmup + self.inv_gamma = inv_gamma + self.power = power + self.optimization_step = 0 + self.cur_decay_value = None # set in `step()` + + self.model_cls = model_cls + self.model_config = model_config + + @classmethod + def from_pretrained(cls, path, model_cls) -> "EMA": + _, ema_kwargs = model_cls.load_config(path, return_unused_kwargs=True) + model = model_cls.from_pretrained(path) + + ema_model = cls(model.parameters(), model_cls=model_cls, model_config=model.config) + + ema_model.load_state_dict(ema_kwargs) + return ema_model + + def save_pretrained(self, path): + if self.model_cls is None: + raise ValueError("`save_pretrained` can only be used if `model_cls` was defined at __init__.") + + if self.model_config is None: + raise ValueError("`save_pretrained` can only be used if `model_config` was defined at __init__.") + + model = self.model_cls.from_config(self.model_config) + state_dict = self.state_dict() + state_dict.pop("shadow_params", None) + + model.register_to_config(**state_dict) + self.copy_to(model.parameters()) + model.save_pretrained(path) + + def get_decay(self, optimization_step: int) -> float: + """ + Compute the decay factor for the exponential moving average. + """ + step = max(0, optimization_step - self.update_after_step - 1) + + if step <= 0: + return 0.0 + + if self.use_ema_warmup: + cur_decay_value = 1 - (1 + step / self.inv_gamma) ** -self.power + else: + cur_decay_value = (1 + step) / (10 + step) + + cur_decay_value = min(cur_decay_value, self.decay) + # make sure decay is not smaller than min_decay + cur_decay_value = max(cur_decay_value, self.min_decay) + return cur_decay_value + + @torch.no_grad() + def step(self, parameters: Iterable[torch.nn.Parameter]): + parameters = list(parameters) + + self.optimization_step += 1 + + # Compute the decay factor for the exponential moving average. + decay = self.get_decay(self.optimization_step) + self.cur_decay_value = decay + one_minus_decay = 1 - decay + + for s_param, param in zip(self.shadow_params, parameters): + if param.requires_grad: + s_param.sub_(one_minus_decay * (s_param - param)) + else: + s_param.copy_(param) + + torch.cuda.empty_cache() + + def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None: + """ + Copy current averaged parameters into given collection of parameters. + + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + updated with the stored moving averages. If `None`, the parameters with which this + `ExponentialMovingAverage` was initialized will be used. + """ + parameters = list(parameters) + for s_param, param in zip(self.shadow_params, parameters): + param.data.copy_(s_param.to(param.device).data) + + def to(self, device=None, dtype=None) -> None: + r"""Move internal buffers of the ExponentialMovingAverage to `device`. + + Args: + device: like `device` argument to `torch.Tensor.to` + """ + # .to() on the tensors handles None correctly + self.shadow_params = [ + p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device) + for p in self.shadow_params + ] + + def state_dict(self) -> dict: + r""" + Returns the state of the ExponentialMovingAverage as a dict. This method is used by accelerate during + checkpointing to save the ema state dict. + """ + # Following PyTorch conventions, references to tensors are returned: + # "returns a reference to the state and not its copy!" - + # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict + return { + "decay": self.decay, + "min_decay": self.min_decay, + "optimization_step": self.optimization_step, + "update_after_step": self.update_after_step, + "use_ema_warmup": self.use_ema_warmup, + "inv_gamma": self.inv_gamma, + "power": self.power, + "shadow_params": self.shadow_params, + } + + def store(self, parameters: Iterable[torch.nn.Parameter]) -> None: + r""" + Args: + Save the current parameters for restoring later. + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + temporarily stored. + """ + self.temp_stored_params = [param.detach().cpu().clone() for param in parameters] + + def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None: + r""" + Args: + Restore the parameters stored with the `store` method. Useful to validate the model with EMA parameters without: + affecting the original optimization process. Store the parameters before the `copy_to()` method. After + validation (or model saving), use this to restore the former parameters. + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + updated with the stored parameters. If `None`, the parameters with which this + `ExponentialMovingAverage` was initialized will be used. + """ + if self.temp_stored_params is None: + raise RuntimeError("This ExponentialMovingAverage has no `store()`ed weights to `restore()`") + for c_param, param in zip(self.temp_stored_params, parameters): + param.data.copy_(c_param.data) + + # Better memory-wise. + self.temp_stored_params = None + + def load_state_dict(self, state_dict: dict) -> None: + r""" + Args: + Loads the ExponentialMovingAverage state. This method is used by accelerate during checkpointing to save the + ema state dict. + state_dict (dict): EMA state. Should be an object returned + from a call to :meth:`state_dict`. + """ + # deepcopy, to be consistent with module API + state_dict = copy.deepcopy(state_dict) + + self.decay = state_dict.get("decay", self.decay) + if self.decay < 0.0 or self.decay > 1.0: + raise ValueError("Decay must be between 0 and 1") + + self.min_decay = state_dict.get("min_decay", self.min_decay) + if not isinstance(self.min_decay, float): + raise ValueError("Invalid min_decay") + + self.optimization_step = state_dict.get("optimization_step", self.optimization_step) + if not isinstance(self.optimization_step, int): + raise ValueError("Invalid optimization_step") + + self.update_after_step = state_dict.get("update_after_step", self.update_after_step) + if not isinstance(self.update_after_step, int): + raise ValueError("Invalid update_after_step") + + self.use_ema_warmup = state_dict.get("use_ema_warmup", self.use_ema_warmup) + if not isinstance(self.use_ema_warmup, bool): + raise ValueError("Invalid use_ema_warmup") + + self.inv_gamma = state_dict.get("inv_gamma", self.inv_gamma) + if not isinstance(self.inv_gamma, (float, int)): + raise ValueError("Invalid inv_gamma") + + self.power = state_dict.get("power", self.power) + if not isinstance(self.power, (float, int)): + raise ValueError("Invalid power") + + shadow_params = state_dict.get("shadow_params", None) + if shadow_params is not None: + self.shadow_params = shadow_params + if not isinstance(self.shadow_params, list): + raise ValueError("shadow_params must be a list") + if not all(isinstance(p, torch.Tensor) for p in self.shadow_params): + raise ValueError("shadow_params must all be Tensors") + + +# calculates entropy over each pixel distribution +def pixel_entropy_per_percent_masked_bucket(logits, input_ids, mask_id): + # only calculated entropy over image tokens that were masked in the original image + masked_tokens = input_ids == mask_id + num_masked_pixels = masked_tokens.sum(-1) + + probs = F.softmax(logits, dim=-1) + log_probs = F.log_softmax(logits, dim=-1) + + entropy_per_pixel = -((probs * log_probs).sum(-1)) + + # the predictions for non-masked aren't used, so set their entropies to zero + entropy_per_pixel[~masked_tokens] = 0 + + entropy_per_image_numerator = entropy_per_pixel.sum(-1) + entropy_per_image = entropy_per_image_numerator / num_masked_pixels + + total_buckets = 10 + masked_buckets = input_ids_to_masked_buckets(input_ids, mask_id, total_buckets) + + entropy_by_masked_bucket = average_by_buckets(entropy_per_image, masked_buckets, total_buckets) + + return entropy_by_masked_bucket + + +# calculates entropy over the averaged distribution of pixels for the whole image +def image_entropy_per_percent_masked_bucket(logits, input_ids, mask_id): + # only calculated entropy over image tokens that were masked in the original image + masked_tokens = input_ids == mask_id + num_masked_pixels = masked_tokens.sum(-1, keepdim=True) + + pixel_probs = F.softmax(logits, dim=-1) + pixel_probs[~masked_tokens] = 0 + image_probs_numerator = pixel_probs.sum(-2) + image_probs = image_probs_numerator / num_masked_pixels + + image_log_probs = image_probs.log() + + entropy_per_image = -((image_probs * image_log_probs).sum(-1)) + + total_buckets = 10 + masked_buckets = input_ids_to_masked_buckets(input_ids, mask_id, total_buckets) + + entropy_by_masked_bucket = average_by_buckets(entropy_per_image, masked_buckets, total_buckets) + + return entropy_by_masked_bucket + + +def cross_entropy_per_percent_masked_bucket(logits, labels, input_ids, mask_id, output_size, label_smoothing): + cross_entropy_per_image = F.cross_entropy( + logits.view(-1, output_size), + labels.view(-1), + ignore_index=-100, + label_smoothing=label_smoothing, + reduction="none", + ) + + total_buckets = 10 + masked_buckets = input_ids_to_masked_buckets(input_ids, mask_id, total_buckets) + + cross_entropy_by_percent_masked_bucket = average_by_buckets(cross_entropy_per_image, masked_buckets, total_buckets) + + return cross_entropy_by_percent_masked_bucket + + +def token_probability_distributions_per_percent_masked_bucket(logits, input_ids, mask_id): + probs = F.softmax(logits, dim=-1) + + total_buckets = 10 + masked_buckets = input_ids_to_masked_buckets(input_ids, mask_id, total_buckets) + + data = [] + + for bucket_idx in range(total_buckets): + indices_for_bucket = masked_buckets[masked_buckets == bucket_idx] + + # It's ok if none were noised in the range of this bucket. This + # function will be called for a later training step where it's likely + # there will be an element noised in the range. + if indices_for_bucket.shape[0] == 0: + continue + + index_for_bucket = indices_for_bucket[0] + + image_probs = probs[index_for_bucket] + + # find the index of a masked pixel for the image + input_ids_for_image = input_ids[index_for_bucket] + masked_pixels_probs = image_probs[input_ids_for_image == mask_id] + + masked_pixel_probs = masked_pixels_probs[0] + + masked_pixel_probs = masked_pixel_probs.cpu().numpy() + + for masked_pixel_prob in masked_pixel_probs: + data.append({"bucket": bucket_idx, "masked_pixel_prob": masked_pixel_prob}) + + df = pd.DataFrame(data) + + return df + + +def average_by_buckets(values, masked_buckets, total_buckets): + unique_buckets, bucket_counts = masked_buckets.unique(dim=0, return_counts=True) + + numerator = torch.zeros(total_buckets, device=values.device) + + numerator.scatter_add_(0, masked_buckets, values) + + # default value is one because the buckets for which there aren't + # any values will have a numerator of zero. So we just need to not divide + # by zero. + denominator = torch.ones(total_buckets, device=values.device, dtype=torch.long) + denominator[unique_buckets] = bucket_counts + + averaged_by_buckets = numerator / denominator + + return averaged_by_buckets + + +def input_ids_to_masked_buckets(input_ids, mask_id, total_buckets=10): + assert total_buckets == 10 + + masked_percent = (input_ids == mask_id).sum(-1) / input_ids.shape[-1] + + # we do not formally use timesteps to noise images. Instead, we mask a percent + # of the pixels. We don't want to log entropy for every mask percent between 0 and 1, + # and we also want to track how the entropy evolves over time w/in a range of mask + # percents that should have similar entropy. So we bucket the masked percents into a + # fixed number of buckets + + # we could generalize this later if needed but for now, let's just assume a fixed + # number of 10 buckets. + + # How this maps to a bucket index: + # (mask) * bucket_index + + # (mask_1) * bucket_index_1 + # + # -> Where the mask is true will be set to the expected bucket index, + # where the mask is false will be set to 0. + # + # Given the probabilities are between 0 and 1, each masked_percent will get mapped + # to a timestep by one and only one of the masks. + + masked_buckets = ( + ((0 < masked_percent) & (masked_percent <= 0.1)) * 0 + + ((0.1 < masked_percent) & (masked_percent <= 0.2)) * 1 + + ((0.2 < masked_percent) & (masked_percent <= 0.3)) * 2 + + ((0.3 < masked_percent) & (masked_percent <= 0.4)) * 3 + + ((0.4 < masked_percent) & (masked_percent <= 0.5)) * 4 + + ((0.5 < masked_percent) & (masked_percent <= 0.6)) * 5 + + ((0.6 < masked_percent) & (masked_percent <= 0.7)) * 6 + + ((0.7 < masked_percent) & (masked_percent <= 0.8)) * 7 + + ((0.8 < masked_percent) & (masked_percent <= 0.9)) * 8 + + ((0.9 < masked_percent) & (masked_percent <= 1.0)) * 9 + ) + + return masked_buckets diff --git a/MMaDA/ommda-training-t2s-mmada/config.yaml b/MMaDA/ommda-training-t2s-mmada/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3867637c2f9bde3ae80749fcfd60404558d0643a --- /dev/null +++ b/MMaDA/ommda-training-t2s-mmada/config.yaml @@ -0,0 +1,83 @@ +wandb: + entity: null + resume: auto + run_id: rj1wt1s4 +experiment: + project: ommda-training-t2s + name: ommda-training-t2s-mmada + output_dir: ommda-training-t2s-mmada + save_every: 5000 + eval_every: 20000 + generate_every: 5000 + num_validation_images: 20 + log_every: 1 + log_grad_norm_every: 100 + resume_from_checkpoint: latest + val_every: 50000 + max_val_examples_t2i: 2000 + logging_dir: ommda-training-t2s-mmada/logs +model: + vq_model: + type: emova + vq_model_name: Emova-ollm/emova_speech_tokenizer_hf + mmada: + tokenizer_path: GSAI-ML/LLaDA-8B-Instruct + pretrained_model_path: Gen-Verse/MMaDA-8B-Base + w_clip_vit: false + new_vocab_size: 138752 + llm_vocab_size: 126464 + codebook_size: 8192 + speech_codebook_size: 4096 + num_new_special_tokens: 3 + tie_word_embeddings: false + gradient_checkpointing: true +dataset: + params: + num_workers: 0 + resolution: 256 + pin_memory: true + persistent_workers: true + preprocessing: + max_seq_length: 256 + resolution: 256 + center_crop: false + random_flip: false + data: + name: gigaspeech + subset: xl + split: train +optimizer: + name: adamw + params: + learning_rate: 0.0001 + scale_lr: false + beta1: 0.9 + beta2: 0.999 + weight_decay: 0.01 + epsilon: 1.0e-08 +lr_scheduler: + scheduler: cosine + params: + learning_rate: ${optimizer.params.learning_rate} + warmup_steps: 2500 + min_lr_scale: 0.1 +training: + gradient_accumulation_steps: 4 + noise_type: mask + batch_size_s2t: 4 + mixed_precision: bf16 + enable_tf32: true + seed: 10086 + max_train_steps: 50000 + overfit_one_batch: false + cond_dropout_prob: 0.1 + min_masking_rate: 0.0 + label_smoothing: 0.0 + max_grad_norm: 1 + guidance_scale: 5 + generation_timesteps: 50 + t2i_coeff: 1.0 + lm_coeff: 0.1 + mmu_coeff: 0.5 + validation_seed: 42 +config: /home/work/AIDAS/MMaDA/configs/mmada_pretraining_t2s.yaml diff --git a/MMaDA/parquet/__init__.py b/MMaDA/parquet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9130922d05681c097a4e829564d1558c967c1efe --- /dev/null +++ b/MMaDA/parquet/__init__.py @@ -0,0 +1,2 @@ +# from .refinedweb_dataset import RefinedWebDataset +from .my_dataset import RefinedWebDataset, ChatDataset, VQADataset \ No newline at end of file diff --git a/MMaDA/parquet/my_dataset.py b/MMaDA/parquet/my_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..01299a8e3bc37727e2792a7541199f8605f2f658 --- /dev/null +++ b/MMaDA/parquet/my_dataset.py @@ -0,0 +1,489 @@ +import collections +import os +import random +import torch +from torch.utils.data import IterableDataset, DataLoader +import pandas as pd +import glob +from typing import List, Dict, Any, Optional, Iterator +import pyarrow.parquet as pq +from transformers import AutoTokenizer +from torchvision import transforms +import json +from PIL import Image + +class RefinedWebDataset(IterableDataset): + def __init__(self, + data_path, + rank: int = 0, + world_size: int = 1, + shuffle=True, + repeat=True, + buffer_size=1000, + max_length=8000, + num_workers=1): + super().__init__() + self.files = sorted(glob.glob(data_path)) + self.rank = rank + self.world_size = world_size + self.shuffle = shuffle + self.repeat = repeat + self.buffer_size = buffer_size + self.max_length = max_length + self.num_workers = num_workers + + self.files = self.files[self.rank::self.world_size] + + def read_parquet_file(self, file_path): + table = pq.read_table(file_path, columns=["content"]) + df = table.to_pandas() + for _, row in df.iterrows(): + yield {"content": row["content"]} + + def __iter__(self): + while True: + file_list = self.files + if self.shuffle: + random.shuffle(file_list) + + for file in file_list: + data_generator = self.read_parquet_file(file) + buffer = [] + + for data in data_generator: + text = data["content"].replace("\n", "") + if len(text) > self.max_length: + start_index = random.randint(0, len(text) - self.max_length - 1) + selected_text = text[start_index:start_index + self.max_length] + else: + selected_text = text + + buffer.append({"input_ids": selected_text}) + + if len(buffer) >= self.buffer_size: + if self.shuffle: + random.shuffle(buffer) + for item in buffer: + yield item + buffer = [] + + if buffer: + if self.shuffle: + random.shuffle(buffer) + for item in buffer: + yield item + + if not self.repeat: + break + + def collate_fn(self, batch): + batched = collections.defaultdict(list) + for data in batch: + for k, v in data.items(): + batched[k].append(v) + + for k, v in batched.items(): + if k not in ('key', 'input_ids', 'similarity'): + batched[k] = torch.stack(v, dim=0) + + return batched + +class ChatDataset(IterableDataset): + def __init__(self, + data_path, + rank: int = 0, + world_size: int = 1, + shuffle=True, + repeat=True, + buffer_size=1000, + max_length=8000, + num_workers=1, + tokenizer=None): + super().__init__() + self.files = sorted(glob.glob(data_path)) + self.rank = rank + self.world_size = world_size + self.shuffle = shuffle + self.repeat = repeat + self.buffer_size = buffer_size + self.max_length = max_length + self.num_workers = num_workers + self.tokenizer = tokenizer + + self.files = self.files[self.rank::self.world_size] + + def read_parquet_file(self, file_path): + table = pq.read_table(file_path, columns=["content"]) + df = table.to_pandas() + for _, row in df.iterrows(): + yield {"content": row["content"]} + + def __iter__(self): + while True: + file_list = self.files + if self.shuffle: + random.shuffle(file_list) + + for file in file_list: + data_generator = self.read_parquet_file(file) + buffer = [] + + for data in data_generator: + text = data["content"] + if self.tokenizer is None: + if len(text) > self.max_length: + start_index = random.randint(0, len(text) - self.max_length - 1) + selected_text = text[start_index:start_index + self.max_length] + else: + selected_text = text + else: + if len(self.tokenizer(text)['input_ids']) < self.max_length: + selected_text = text + else: + continue + + buffer.append({"input_ids": selected_text}) + + if len(buffer) >= self.buffer_size: + if self.shuffle: + random.shuffle(buffer) + for item in buffer: + yield item + buffer = [] + + if buffer: + if self.shuffle: + random.shuffle(buffer) + for item in buffer: + yield item + + if not self.repeat: + break + + def collate_fn(self, batch): + batched = collections.defaultdict(list) + for data in batch: + for k, v in data.items(): + batched[k].append(v) + + for k, v in batched.items(): + if k not in ('key', 'input_ids', 'similarity'): + batched[k] = torch.stack(v, dim=0) + + return batched + +class R2iDataset(IterableDataset): + def __init__(self, + data_path, + rank: int = 0, + world_size: int = 1, + shuffle=True, + repeat=True, + buffer_size=1000, + max_length=8000, + num_workers=1, + resolution=256, + tokenizer=None): + super().__init__() + self.data_path = data_path + self.rank = rank + self.world_size = world_size + self.shuffle = shuffle + self.repeat = repeat + self.buffer_size = buffer_size + self.max_length = max_length + self.num_workers = num_workers + self.tokenizer = tokenizer + self.resolution = resolution + + def __iter__(self): + while True: + subdirs = sorted([d for d in glob.glob(os.path.join(self.data_path, "*")) if os.path.isdir(d)]) + + if self.shuffle: + random.shuffle(subdirs) + + subdirs = subdirs[self.rank::self.world_size] + + subdirs = ['/data_storage/lbw/datasets/laion-aesthetics-12m-images-2/00000'] + + for subdir in subdirs: + all_files = glob.glob(os.path.join(subdir, "*.*")) + base_names = set() + + for file_path in all_files: + base_name = os.path.splitext(os.path.basename(file_path))[0] + base_names.add(base_name) + + base_names = list(base_names) + if self.shuffle: + random.shuffle(base_names) + + buffer = [] + + for base_name in base_names: + jpg_path = os.path.join(subdir, f"{base_name}.jpg") + caption_path = os.path.join(subdir, f"{base_name}.caption") + shortcaption_path = os.path.join(subdir, f"{base_name}.shortcaption") + + if not os.path.exists(jpg_path): + continue + + try: + image = Image.open(jpg_path).convert("RGB") + + caption = "" + if os.path.exists(caption_path): + with open(caption_path, "r", encoding="utf-8") as f: + caption = f.read().strip() + + short_caption = "" + if os.path.exists(shortcaption_path): + with open(shortcaption_path, "r", encoding="utf-8") as f: + short_caption = f.read().strip() + + transformed_image = image_transform_clip({"images": image}, resolution=self.resolution)["images"] + + if self.tokenizer is not None: + if len(self.tokenizer(caption)['input_ids']) > self.max_length - 2: + continue + + prompt = ( + '<|start_header_id|>user<|end_header_id|>\n' + "You should first think out a more detailed version of the description and then provide the user with the image. The detailed description is enclosed within tags, i.e. detailed description here image here\n" + f"{short_caption}" + '<|start_header_id|>assistant<|end_header_id|>\n' + f"{caption}" + ) + + sample = { + "images": transformed_image, + "input_ids": prompt, + } + + buffer.append(sample) + + if len(buffer) >= self.buffer_size: + if self.shuffle: + random.shuffle(buffer) + for item in buffer: + yield item + buffer = [] + + except Exception as e: + print(f"Error processing {jpg_path}: {e}") + continue + + if buffer: + if self.shuffle: + random.shuffle(buffer) + for item in buffer: + yield item + + if not self.repeat: + break + + def collate_fn(self, batch): + batched = collections.defaultdict(list) + for data in batch: + for k, v in data.items(): + batched[k].append(v) + + for k, v in batched.items(): + if k not in ('key', 'input_ids', 'similarity'): + batched[k] = torch.stack(v, dim=0) + + return batched + +class VQADataset(IterableDataset): + def __init__(self, + json_path: str, + image_root: str, + tokenizer = None, + rank: int = 0, + world_size: int = 1, + shuffle: bool = True, + repeat: bool = True, + buffer_size: int = 100, + resolution: int = 256, + max_length: int = 8000, + num_workers: int = 1, + image_transform_method: str = "squash"): + super().__init__() + self.json_path = json_path + self.image_root = image_root + self.tokenizer = tokenizer + self.rank = rank + self.world_size = world_size + self.shuffle = shuffle + self.repeat = repeat + self.buffer_size = buffer_size + self.resolution = resolution + self.max_length = max_length + self.num_workers = num_workers + self.image_transform_method = image_transform_method + try: + with open(self.json_path, 'r', encoding='utf-8') as f: + raw_data = json.load(f) + except FileNotFoundError: + print(f"Error: Data file not found at {self.json_path}") + self.list_data_dict = [] + except json.JSONDecodeError: + print(f"Error: Could not decode JSON from {self.json_path}") + self.list_data_dict = [] + else: + self.list_data_dict = [item for item in raw_data if 'image' in item and 'conversations' in item] + self.list_data_dict = self.list_data_dict[self.rank::self.world_size] + def __iter__(self): + sot_token = '<|startoftext|>' + assistant_prompt_suffix = '<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n' + while True: + current_data_list = list(self.list_data_dict) + if self.shuffle: + random.shuffle(current_data_list) + buffer = [] + for item in current_data_list: + image_relative_path = item.get('image') + conversations = item.get('conversations', []) + if not image_relative_path or not conversations or len(conversations) < 2: + continue + num_total_messages = len(conversations) + if num_total_messages % 2 != 0: + conversations = conversations[:-1] + num_total_messages -= 1 + if num_total_messages < 2: continue + num_turns = num_total_messages // 2 + if num_turns == 0: + continue + selected_num_turns = random.randint(1, num_turns) + selected_conversations = conversations[:selected_num_turns * 2] + image_path = os.path.join(self.image_root, image_relative_path) + try: + image = Image.open(image_path).convert("RGB") + if self.image_transform_method == "squash": + transformed_image = image_transform_squash({"images": image}, resolution=self.resolution)["images"] + elif self.image_transform_method == "pad": + transformed_image = image_transform_pad({"images": image}, resolution=self.resolution)["images"] + else: + transformed_image = image_transform_clip({"images": image}, resolution=self.resolution)["images"] + first_human_message = selected_conversations[0]['value'] + processed_message = first_human_message.replace('\n', '').replace('\n', '') + current_selection_messages = list(selected_conversations) + current_selection_messages[0] = dict(current_selection_messages[0]) + current_selection_messages[0]['value'] = processed_message + messages = [] + for turn in current_selection_messages: + role = "user" if turn["from"] == "human" else "assistant" + messages.append({"role": role, "content": turn["value"]}) + formatted_text = self.tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True + ) + if formatted_text.startswith(sot_token): + formatted_text = formatted_text[len(sot_token):] + if formatted_text.endswith(assistant_prompt_suffix): + formatted_text = formatted_text[:-len(assistant_prompt_suffix)] + token_ids = self.tokenizer(formatted_text)['input_ids'] + if len(token_ids) > self.max_length: + continue + sample = { + "images": transformed_image, + "input_ids": formatted_text, + } + buffer.append(sample) + if len(buffer) >= self.buffer_size: + if self.shuffle: + random.shuffle(buffer) + for buf_item in buffer: + yield buf_item + buffer = [] + except FileNotFoundError: + print(f"Warning: Image file not found at {image_path}, skipping item.") + continue + except Exception as e: + print(f"Warning: Error processing item with image {image_path}: {e}, skipping.") + continue + if buffer: + if self.shuffle: + random.shuffle(buffer) + for buf_item in buffer: + yield buf_item + if not self.repeat: + break + def collate_fn(self, batch): + batched = collections.defaultdict(list) + for data in batch: + for k, v in data.items(): + batched[k].append(v) + for k, v in batched.items(): + if k not in ('key', 'input_ids', 'similarity'): + batched[k] = torch.stack(v, dim=0) + return batched + +def image_transform_clip(sample, resolution=256): + image = sample["images"] + image = transforms.Resize(resolution, interpolation=transforms.InterpolationMode.BICUBIC)(image) + image = transforms.CenterCrop((resolution, resolution))(image) + image = transforms.ToTensor()(image) + image = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)(image) + sample["images"] = image + return sample + +def image_transform_squash(sample, resolution=256): + image = sample["images"] + image = transforms.Resize((resolution, resolution), interpolation=transforms.InterpolationMode.BICUBIC)(image) + image = transforms.ToTensor()(image) + image = transforms.Normalize(mean=[0.5, 0.5, 0.5],std=[0.5, 0.5, 0.5])(image) + sample["images"] = image + return sample + +def image_transform_pad(sample, resolution=256, fill_color=(255, 255, 255)): + image = sample["images"] + w, h = image.size + if w == h: + padded_image = image + elif w < h: + padding_needed = h - w + padding_left = padding_needed // 2 + padding_right = padding_needed - padding_left + pad_transform = transforms.Pad((padding_left, 0, padding_right, 0), fill=fill_color, padding_mode='constant') + padded_image = pad_transform(image) + else: + padding_needed = w - h + padding_top = padding_needed // 2 + padding_bottom = padding_needed - padding_top + pad_transform = transforms.Pad((0, padding_top, 0, padding_bottom), fill=fill_color, padding_mode='constant') + padded_image = pad_transform(image) + image_resized = transforms.Resize((resolution, resolution), interpolation=transforms.InterpolationMode.BICUBIC)(padded_image) + image_tensor = transforms.ToTensor()(image_resized) + image_normalized = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])(image_tensor) + sample["images"] = image_normalized + return sample + +if __name__ == '__main__': + data_path = "/data_storage/shared/datasets/falcon-refinedweb/data/data/*.parquet" + dataset = RefinedWebDataset( + data_path=data_path, + max_length=8000, + buffer_size=0, + ) + + from torch.utils.data import DataLoader + train_dataloader = DataLoader( + dataset, + batch_size=1, + sampler=None, + collate_fn=dataset.collate_fn, + num_workers=0 + ) + + print("Starting data loading test...") + for i, batch in enumerate(train_dataloader): + if i == 0: + print(batch) + print(f"Batch size: {len(batch['input_ids'])}") + print(f"First sample length: {len(batch['input_ids'][0])}") + if i >= 5: + break + print("Data loading test complete") \ No newline at end of file diff --git a/MMaDA/precompute_instructs2s_tokens.py b/MMaDA/precompute_instructs2s_tokens.py new file mode 100644 index 0000000000000000000000000000000000000000..0a6b95dcc20b0273d65fb8f4ed6c192f52c38f9e --- /dev/null +++ b/MMaDA/precompute_instructs2s_tokens.py @@ -0,0 +1,231 @@ +#!/usr/bin/env python3 +""" +Pre-compute EMOVA speech tokenizer codes for InstructS2S (ė˜ėŠ” źø°ķƒ€ ė‹Øģ¼ ģ˜¤ė””ģ˜¤ ķ“ė”). + +ģ˜ˆģ‹œ: + python /home/work/AIDAS/MMaDA/precompute_instructs2s_tokens.py \ + --audio-root /home/work/AIDAS/data/InstructS2S-200K/en/wav \ + --output-root /home/work/AIDAS/cache/instructs2s_tokens \ + --pairs-file /home/work/AIDAS/data/InstructS2S-200K/en/wav/pairs.txt + +sha1(ģ ˆėŒ€ź²½ė”œ) 기반 ģŗģ‹œ 구씰넼 ģ‚¬ģš©ķ•˜ėÆ€ė”œ, ķ•™ģŠµ ģ½”ė“œģ—ģ„œ źø°ėŒ€ķ•˜ėŠ” 디렉터리 +(`MixedSpeechTextDataset`, `Speech2SpeechDataset`)와 ė™ģ¼ķ•˜ź²Œ ė™ģž‘ķ•©ė‹ˆė‹¤. +""" + +from __future__ import annotations + +import argparse +import hashlib +import os +import tempfile +from pathlib import Path +from typing import Iterable, Iterator, Optional, Sequence + +import soundfile as sf +import torch +from tqdm import tqdm + +# Ensure project root on path +REPO_ROOT = Path(__file__).resolve().parents[1] +if str(REPO_ROOT) not in os.sys.path: + os.sys.path.append(str(REPO_ROOT)) + +from models.modeling_emova_speech_tokenizer import EMOVASpeechTokenizer # noqa: E402 + + +def iter_instructs2s_audio( + audio_root: Path, pairs_file: Optional[Path] = None +) -> Iterator[Path]: + """ + InstructS2S 루트넼 ģˆœķšŒķ•˜ė©° user/assistant wav 경딜넼 모두 yield ķ•©ė‹ˆė‹¤. + + `pairs.txt`ź°€ 제공되멓 ź°€ģž„ ģš°ģ„ ģœ¼ė”œ ģ‚¬ģš©ķ•˜ź³ , ģ—†ģœ¼ė©“ 디렉터리 구씰넼 ģˆœķšŒķ•©ė‹ˆė‹¤. + """ + resolved_root = audio_root.expanduser().resolve() + pairs_candidate = pairs_file + if pairs_candidate is None: + candidate = resolved_root / "pairs.txt" + if candidate.exists(): + pairs_candidate = candidate + + if pairs_candidate is not None: + with pairs_candidate.open("r") as fh: + for line in fh: + line = line.strip() + if not line: + continue + parts = line.split() + if len(parts) < 2: + continue + user_path = Path(parts[0]) + if not user_path.is_absolute(): + user_path = resolved_root / user_path + assistant_path = Path(parts[1]) + if not assistant_path.is_absolute(): + assistant_path = resolved_root / assistant_path + if user_path.is_file(): + yield user_path + if assistant_path.is_file(): + yield assistant_path + return + + # pairs.txtź°€ ģ—†ģœ¼ė©“ 디렉터리 순회 + for dir_path in sorted(resolved_root.iterdir()): + if not dir_path.is_dir(): + continue + dir_name = dir_path.name + k = 1 + while True: + user_wav = dir_path / f"{dir_name}-{k}-user.wav" + assistant_wav = dir_path / f"{dir_name}-{k}-assistant.wav" + if user_wav.is_file() and assistant_wav.is_file(): + yield user_wav + yield assistant_wav + k += 1 + continue + break + + +def hash_path(path: Path) -> str: + """ģ ˆėŒ€ 경딜넼 sha1으딜 ķ•“ģ‹œķ•œ 40źø€ģž hex ė°˜ķ™˜.""" + abs_path = os.path.abspath(path) + return hashlib.sha1(abs_path.encode("utf-8")).hexdigest() + + +def token_output_path(output_root: Path, audio_path: Path) -> Path: + digest = hash_path(audio_path) + return output_root / digest[:2] / digest[2:4] / f"{digest}.pt" + + +def encode_audio(tokenizer: EMOVASpeechTokenizer, audio_path: Path) -> torch.Tensor: + """ + EMOVA ķ† ķ¬ė‚˜ģ“ģ €ė”œ ģ˜¤ė””ģ˜¤ė„¼ 토큰화. + 비-WAV ķ¬ė§·ģ€ ģž„ģ‹œ ķŒŒģ¼ė”œ ė³€ķ™˜ 후 ģ²˜ė¦¬ķ•©ė‹ˆė‹¤. + """ + suffix = audio_path.suffix.lower() + if suffix == ".wav": + return tokenizer.encode(str(audio_path)).cpu() + + data, sample_rate = sf.read(str(audio_path)) + tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False) + try: + sf.write(tmp.name, data, sample_rate) + tokens = tokenizer.encode(tmp.name).cpu() + finally: + tmp.close() + try: + os.remove(tmp.name) + except OSError: + pass + return tokens + + +def gather_audio_paths(audio_root: Path, pairs_file: Optional[Path]) -> list[Path]: + paths = list(iter_instructs2s_audio(audio_root, pairs_file)) + # 중복 제거 + seen = set() + unique: list[Path] = [] + for path in paths: + if path not in seen: + seen.add(path) + unique.append(path) + return unique + + +def main(): + parser = argparse.ArgumentParser(description="Pre-compute EMOVA speech tokens for InstructS2S.") + parser.add_argument( + "--audio-root", + type=Path, + default=Path("/home/work/AIDAS/data/InstructS2S-200K/en/wav"), + help="user/assistant WAVź°€ ģœ„ģ¹˜ķ•œ 루트 디렉터리", + ) + parser.add_argument( + "--pairs-file", + type=Path, + default=None, + help="ģ„ ķƒ 사항: pairs.txt 경딜 (ģ§€ģ •ķ•˜ģ§€ ģ•Šģœ¼ė©“ audio-root/pairs.txt ķƒģƒ‰)", + ) + parser.add_argument( + "--output-root", + type=Path, + default=Path("/home/work/AIDAS/cache/instructs2s_tokens"), + help="ķ† ķ°ģ„ ģ €ģž„ķ•  디렉터리", + ) + parser.add_argument( + "--tokenizer", + type=str, + default="Emova-ollm/emova_speech_tokenizer_hf", + help="EMOVA speech tokenizer ģ²“ķ¬ķ¬ģøķŠø", + ) + parser.add_argument( + "--device", + type=str, + default="cuda" if torch.cuda.is_available() else "cpu", + help="ģøģ½”ė”©ģ— ģ‚¬ģš©ķ•  ė””ė°”ģ“ģŠ¤", + ) + parser.add_argument( + "--overwrite", + action="store_true", + help="ģ“ėÆø ģ”“ģž¬ķ•˜ėŠ” ķ† ķ°ģ„ ė‹¤ģ‹œ ź³„ģ‚°ķ•©ė‹ˆė‹¤.", + ) + args = parser.parse_args() + + audio_root = args.audio_root.expanduser().resolve() + if not audio_root.exists(): + parser.error(f"Audio root not found: {audio_root}") + + pairs_file = args.pairs_file.expanduser().resolve() if args.pairs_file else None + if pairs_file is not None and not pairs_file.exists(): + parser.error(f"pairs-file not found: {pairs_file}") + + output_root = args.output_root.expanduser().resolve() + output_root.mkdir(parents=True, exist_ok=True) + + audio_paths = gather_audio_paths(audio_root, pairs_file) + if not audio_paths: + print("No audio files found. Nothing to encode.") + return + + device = torch.device(args.device) + if device.type == "cuda": + torch.cuda.set_device(device) + tokenizer = EMOVASpeechTokenizer.from_pretrained(args.tokenizer).to(device) + tokenizer.eval() + + total = 0 + skipped = 0 + failed: list[Path] = [] + + for audio_path in tqdm(audio_paths, desc="Encoding InstructS2S clips"): + audio_path = audio_path.expanduser().resolve() + out_path = token_output_path(output_root, audio_path) + if out_path.exists() and not args.overwrite: + skipped += 1 + continue + + out_path.parent.mkdir(parents=True, exist_ok=True) + try: + tokens = encode_audio(tokenizer, audio_path) + except Exception as exc: + tqdm.write(f"[WARN] Failed to encode {audio_path}: {exc}") + failed.append(audio_path) + continue + + tmp_path = out_path.with_suffix(out_path.suffix + ".tmp") + torch.save(tokens, tmp_path) + os.replace(tmp_path, out_path) + total += 1 + + if failed: + failed_log = output_root / "failed_paths.log" + with failed_log.open("a") as fh: + for path in failed: + fh.write(f"{path}\n") + print(f"Failed to encode {len(failed)} files. See {failed_log}") + + print(f"Done. Encoded {total} files. Skipped {skipped} existing entries.") + + +if __name__ == "__main__": + main() diff --git a/MMaDA/precompute_video_speech_tokens.py b/MMaDA/precompute_video_speech_tokens.py new file mode 100644 index 0000000000000000000000000000000000000000..1d4d4c13a34d95ea74cf2567533f1dec12837c6c --- /dev/null +++ b/MMaDA/precompute_video_speech_tokens.py @@ -0,0 +1,407 @@ +#!/usr/bin/env python3 +""" +Pre-compute EMOVA speech tokenizer codes for audio datasets. + +Supported dataset types: + - video-speech : CSV index with truncated WAV clips (e.g., OpenVid speech) + - librispeech : LibriSpeech directory structure with FLAC audio + - instructs2s : InstructS2S-200K style user/assistant WAV pairs + +Examples +-------- + # VideoSpeech + python MMaDA/precompute_video_speech_tokens.py \\ + --dataset-type video-speech \\ + --index /home/work/AIDAS/data/video-speech/openvid-speech.csv \\ + --audio-root /home/work/AIDAS/data/video-speech/openvid-speech-trunc \\ + --output /home/work/AIDAS/cache/video_speech_tokens + + # LibriSpeech + python MMaDA/precompute_video_speech_tokens.py \\ + --dataset-type librispeech \\ + --audio-root /home/work/AIDAS/data/audio/LibriSpeech \\ + --librispeech-subsets train-clean-360 train-clean-100 \\ + --output /home/work/AIDAS/cache/librispeech_tokens + + # InstructS2S (pairs.txt assumed under audio root) + python MMaDA/precompute_video_speech_tokens.py \\ + --dataset-type instructs2s \\ + --audio-root /home/work/AIDAS/data/InstructS2S-200K/en/wav \\ + --output /home/work/AIDAS/cache/instructs2s_tokens +""" + +import argparse +import csv +import hashlib +import os +import sys +import tempfile +from pathlib import Path +from typing import Iterable, Iterator, List, Set + +import soundfile as sf +import torch +from tqdm import tqdm + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from models.modeling_emova_speech_tokenizer import EMOVASpeechTokenizer # noqa: E402 + + +def iter_video_speech_audio(index_path: Path, audio_root: Path) -> Iterator[Path]: + """ + Yields audio paths from the VideoSpeech index CSV. + """ + with index_path.open("r", newline="") as csvfile: + reader = csv.reader(csvfile) + for row in reader: + if not row: + continue + base = row[0].strip().removesuffix(".wav") + if not base: + continue + audio_path = audio_root / f"{base}.wav" + if audio_path.is_file(): + yield audio_path + + +def iter_librispeech_audio(audio_root: Path, subsets: Iterable[str]) -> Iterator[Path]: + """ + Iterates through LibriSpeech FLAC files for the provided subsets. + """ + for subset in subsets: + subset_dir = audio_root / subset + if not subset_dir.exists(): + raise FileNotFoundError(f"LibriSpeech subset not found: {subset_dir}") + speakers = sorted(p for p in subset_dir.iterdir() if p.is_dir()) + for speaker_dir in speakers: + chapters = sorted(p for p in speaker_dir.iterdir() if p.is_dir()) + for chapter_dir in chapters: + for flac_path in sorted(chapter_dir.glob("*.flac")): + yield flac_path + + +def iter_instructs2s_audio(audio_root: Path, pairs_file: Path | None = None) -> Iterator[Path]: + """ + Yields unique audio paths from an InstructS2S root directory. + If pairs_file is provided (or found under audio_root), it's expected to contain + two space-separated paths per line: user assistant. + Otherwise, the directory tree is scanned similarly to Speech2SpeechDataset. + """ + resolved_root = audio_root.expanduser().resolve() + if pairs_file is None: + candidate = resolved_root / "pairs.txt" + if candidate.exists(): + pairs_file = candidate + if pairs_file is not None: + with Path(pairs_file).open("r") as fh: + for line in fh: + line = line.strip() + if not line: + continue + parts = line.split() + if len(parts) >= 2: + user_path = Path(parts[0]) + if not user_path.is_absolute(): + user_path = resolved_root / user_path + asst_path = Path(parts[1]) + if not asst_path.is_absolute(): + asst_path = resolved_root / asst_path + if user_path.is_file(): + yield user_path + if asst_path.is_file(): + yield asst_path + return + + dirs = [p for p in resolved_root.glob("*") if p.is_dir()] + for dir_path in dirs: + dir_name = dir_path.name + k = 1 + while True: + user_wav = dir_path / f"{dir_name}-{k}-user.wav" + assistant_wav = dir_path / f"{dir_name}-{k}-assistant.wav" + if user_wav.is_file() and assistant_wav.is_file(): + yield user_wav + yield assistant_wav + k += 1 + continue + break + + +def hash_path(path: Path) -> str: + """Returns a SHA-1 hex digest for the absolute path.""" + abs_path = os.path.abspath(path) + return hashlib.sha1(abs_path.encode("utf-8")).hexdigest() + + +def token_output_path(output_root: Path, audio_path: Path) -> Path: + """Resolves the on-disk location for cached tokens corresponding to audio_path.""" + digest = hash_path(audio_path) + return output_root / digest[:2] / digest[2:4] / f"{digest}.pt" + + +def encode_audio(tokenizer: EMOVASpeechTokenizer, audio_path: Path) -> torch.Tensor: + """ + Encodes an audio file to discrete tokens, converting non-WAV inputs on the fly. + """ + suffix = audio_path.suffix.lower() + if suffix == ".wav": + return tokenizer.encode(str(audio_path)).cpu() + + data, sample_rate = sf.read(str(audio_path)) + tmp_file = tempfile.NamedTemporaryFile(suffix=".wav", delete=False) + try: + sf.write(tmp_file.name, data, sample_rate) + tokens = tokenizer.encode(tmp_file.name).cpu() + finally: + tmp_file.close() + try: + os.remove(tmp_file.name) + except OSError: + pass + return tokens + + +def gather_audio_paths(args) -> List[Path]: + if args.dataset_type == "video-speech": + return list(iter_video_speech_audio(args.index, args.audio_root)) + if args.dataset_type == "librispeech": + return list(iter_librispeech_audio(args.audio_root, args.librispeech_subsets)) + # instructs2s + paths = list(iter_instructs2s_audio(args.audio_root, args.pairs_file)) + # Deduplicate while preserving order + seen: Set[Path] = set() + unique_paths: List[Path] = [] + for path in paths: + if path not in seen: + seen.add(path) + unique_paths.append(path) + return unique_paths + + +def split_into_shards(items: List[Path], shard_count: int) -> List[List[Path]]: # pragma: no cover - simple helper + shard_count = max(1, shard_count) + shard_size = (len(items) + shard_count - 1) // shard_count + return [items[i * shard_size : (i + 1) * shard_size] for i in range(shard_count)] + + +def process_shard( + shard_id: int, + audio_paths: List[Path], + device: str, + tokenizer_name: str, + output_root: Path, + overwrite: bool, + dataset_type: str, +) -> tuple[int, int, List[Path]]: + if not audio_paths: + return 0, 0, [] + + device_obj = torch.device(device) + if device_obj.type == "cuda": + torch.cuda.set_device(device_obj) + tokenizer = EMOVASpeechTokenizer.from_pretrained(tokenizer_name).to(device_obj) + tokenizer.eval() + + total = 0 + skipped = 0 + desc = f"{dataset_type} worker {shard_id}" + failed_paths: List[Path] = [] + for audio_path in tqdm(audio_paths, desc=desc, position=shard_id, leave=False): + out_path = token_output_path(output_root, audio_path) + if out_path.exists() and not overwrite: + skipped += 1 + continue + out_path.parent.mkdir(parents=True, exist_ok=True) + try: + tokens = encode_audio(tokenizer, audio_path) + except Exception as exc: # pragma: no cover - runtime diagnostics + tqdm.write(f"[WARN][worker {shard_id}] Failed to encode {audio_path}: {exc}") + failed_paths.append(audio_path) + continue + tmp_path = out_path.with_suffix(out_path.suffix + ".tmp") + torch.save(tokens, tmp_path) + os.replace(tmp_path, out_path) + total += 1 + return total, skipped, failed_paths + + +def main(): + parser = argparse.ArgumentParser(description="Pre-compute speech tokens for audio datasets.") + parser.add_argument( + "--dataset-type", + "--dataset_type", + dest="dataset_type", + choices=["video-speech", "librispeech", "instructs2s"], + default="video-speech", + help="Dataset type to process.", + ) + parser.add_argument( + "--index", + type=Path, + help="CSV index for video-speech datasets (required for dataset-type=video-speech).", + ) + parser.add_argument( + "--audio-root", + type=Path, + required=True, + help="Root directory containing audio files. For LibriSpeech this should be the LibriSpeech root.", + ) + parser.add_argument( + "--librispeech_subsets", + nargs="+", + default=None, + help="LibriSpeech subsets to process (e.g., train-clean-360). Required when dataset-type=librispeech.", + ) + parser.add_argument( + "--pairs-file", + "--pairs_file", + type=Path, + default=None, + help="Optional pairs.txt to use for instructs2s dataset.", + ) + parser.add_argument( + "--output", + type=Path, + required=True, + help="Directory to store the precomputed token tensors.", + ) + parser.add_argument( + "--tokenizer", + type=str, + default="Emova-ollm/emova_speech_tokenizer_hf", + help="Name or path of the EMOVA speech tokenizer checkpoint to use.", + ) + parser.add_argument( + "--overwrite", + action="store_true", + help="Recompute and overwrite existing token files.", + ) + parser.add_argument( + "--device", + type=str, + default="cuda" if torch.cuda.is_available() else "cpu", + help="Device for running the tokenizer encoder.", + ) + parser.add_argument( + "--devices", + nargs="+", + default=None, + help="Optional list of devices per worker (e.g., cuda:0 cuda:1 ...). Overrides --device/--num-workers.", + ) + parser.add_argument( + "--num-workers", + type=int, + default=1, + help="Number of parallel workers (ignored if --devices is provided).", + ) + args = parser.parse_args() + + if args.index is not None: + args.index = args.index.expanduser().resolve() + if not args.index.exists(): + parser.error(f"Index file not found: {args.index}") + + if args.dataset_type == "video-speech" and args.index is None: + parser.error("--index is required when dataset-type=video-speech.") + if args.dataset_type == "librispeech" and not args.librispeech_subsets: + parser.error("--librispeech-subsets must be provided when dataset-type=librispeech.") + + args.audio_root = args.audio_root.expanduser().resolve() + args.output = args.output.expanduser().resolve() + if args.pairs_file is not None: + args.pairs_file = Path(args.pairs_file).expanduser().resolve() + if not args.pairs_file.exists(): + parser.error(f"pairs-file not found: {args.pairs_file}") + + args.output.mkdir(parents=True, exist_ok=True) + + audio_paths = gather_audio_paths(args) + if not audio_paths: + print("No audio files found. Nothing to encode.") + return + + if args.devices: + worker_devices = args.devices + else: + worker_devices = [args.device] * max(1, args.num_workers) + + if len(worker_devices) == 1: + device = torch.device(worker_devices[0]) + tokenizer = EMOVASpeechTokenizer.from_pretrained(args.tokenizer).to(device) + tokenizer.eval() + + total = 0 + skipped = 0 + failed_paths: List[Path] = [] + for audio_path in tqdm(audio_paths, desc="Encoding clips"): + out_path = token_output_path(args.output, audio_path) + if out_path.exists() and not args.overwrite: + skipped += 1 + continue + + out_path.parent.mkdir(parents=True, exist_ok=True) + try: + tokens = encode_audio(tokenizer, audio_path) + except Exception as exc: + tqdm.write(f"[WARN] Failed to encode {audio_path}: {exc}") + failed_paths.append(audio_path) + continue + + tmp_path = out_path.with_suffix(out_path.suffix + ".tmp") + torch.save(tokens, tmp_path) + os.replace(tmp_path, out_path) + total += 1 + if failed_paths: + failed_log = args.output / "failed_paths.log" + with failed_log.open("a") as fh: + for path in failed_paths: + fh.write(f"{path}\n") + print(f"Wrote {len(failed_paths)} failed paths to {failed_log}") + print(f"Done. Encoded {total} clips. Skipped {skipped} existing entries.") + return + + shards = split_into_shards(audio_paths, len(worker_devices)) + from multiprocessing import get_context + + ctx = get_context("spawn") + futures = [] + with ctx.Pool(len(worker_devices)) as pool: + for shard_id, (device_str, shard_paths) in enumerate(zip(worker_devices, shards)): + futures.append( + pool.apply_async( + process_shard, + ( + shard_id, + shard_paths, + device_str, + args.tokenizer, + args.output, + args.overwrite, + args.dataset_type, + ), + ) + ) + pool.close() + pool.join() + + total = 0 + skipped = 0 + failed_paths: List[Path] = [] + for fut in futures: + shard_total, shard_skipped, shard_failed = fut.get() + total += shard_total + skipped += shard_skipped + failed_paths.extend(shard_failed) + + if failed_paths: + failed_log = args.output / "failed_paths.log" + with failed_log.open("a") as fh: + for path in failed_paths: + fh.write(f"{path}\n") + print(f"Wrote {len(failed_paths)} failed paths to {failed_log}") + + print(f"Done. Encoded {total} clips. Skipped {skipped} existing entries.") + + +if __name__ == "__main__": + main() diff --git a/MMaDA/requirements.txt b/MMaDA/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..209707193c7cbbcea152e93ef09b009b923c29d6 --- /dev/null +++ b/MMaDA/requirements.txt @@ -0,0 +1,19 @@ +transformers==4.46.0 +wandb +webdataset +omegaconf +pyarrow +pandas +jaxtyping +diffusers==0.32.2 +typeguard +accelerate +deepspeed +lightning +datasets +image-reward +torchmetrics==1.6.2 +clip @ git+https://github.com/openai/CLIP.git +trl==0.17.0 +math_verify +gradio>=4.44.1 \ No newline at end of file diff --git a/MMaDA/script/check_audio.py b/MMaDA/script/check_audio.py new file mode 100644 index 0000000000000000000000000000000000000000..7a478052544f4749f237f9f3e2221329bfb651e6 --- /dev/null +++ b/MMaDA/script/check_audio.py @@ -0,0 +1,88 @@ +import json +import os +import soundfile as sf +from pathlib import Path + +DATA_CONFIG = [ + {"name": "gigaspeech", "subset": "xl", "split": "train"}, + {"name": "librispeech", "subset": "train-clean-360"}, + {"name": "commonvoice", "subset": "validated"} +] + +ROOTS = { + "gigaspeech": "/home/work/AIDAS/data/audio/GigaSpeech", + "librispeech": "/home/work/AIDAS/data/audio/LibriSpeech", + "commonvoice": "/home/work/AIDAS/data/audio/commonvoice/cv-corpus-22.0-2025-06-20/en" +} + +def iter_gigaspeech(cfg): + import datasets + ds = datasets.load_dataset("speechcolab/gigaspeech", cfg["subset"], split=cfg["split"]) + for row in ds: + yield row["audio"]["path"] + +def iter_librispeech(cfg): + subset_root = Path(ROOTS["librispeech"]) / cfg["subset"] + for txt in subset_root.glob("*/**/*.txt"): + with txt.open() as f: + for line in f: + parts = line.strip().split() + if not parts: + continue + audio_id = parts[0] + speaker, chapter, _ = audio_id.split("-") + audio_path = subset_root / speaker / chapter / f"{audio_id}.flac" + yield audio_path + +def iter_commonvoice(cfg): + import pandas as pd + tsv = Path(ROOTS["commonvoice"]) / f"{cfg['subset']}.tsv" + df = pd.read_csv(tsv, sep="\t", usecols=["path"]) + clips_root = Path(ROOTS["commonvoice"]) / "clips" + for rel in df["path"]: + yield clips_root / rel + +DISPATCH = { + "gigaspeech": iter_gigaspeech, + "librispeech": iter_librispeech, + "commonvoice": iter_commonvoice, +} + +def main(): + total_sec = 0.0 + total_files = 0 + per_dataset = [] + + for cfg in DATA_CONFIG: + name = cfg["name"] + iterator = DISPATCH[name](cfg) + subset_total = 0.0 + subset_files = 0 + + for audio_path in iterator: + if not os.path.isfile(audio_path): + continue + info = sf.info(str(audio_path)) + duration = info.frames / info.samplerate + subset_total += duration + subset_files += 1 + total_sec += duration + total_files += 1 + + per_dataset.append({ + "name": name, + "subset": cfg.get("subset"), + "split": cfg.get("split"), + "num_files": subset_files, + "avg_seconds": subset_total / subset_files if subset_files else 0.0, + }) + + summary = { + "total_files": total_files, + "overall_avg_seconds": total_sec / total_files if total_files else 0.0, + "per_dataset": per_dataset, + } + print(json.dumps(summary, indent=2)) + +if __name__ == "__main__": + main() diff --git a/MMaDA/script/debug_llavavid_decode.py b/MMaDA/script/debug_llavavid_decode.py new file mode 100644 index 0000000000000000000000000000000000000000..b2d103ba7eb34993de7430f3b3eae2dc40287167 --- /dev/null +++ b/MMaDA/script/debug_llavavid_decode.py @@ -0,0 +1,222 @@ +#!/usr/bin/env python +""" +Utility script to stress-test LLaVA-Video frame decoding in isolation. + +This runs the `VideoCaptionDataset` loader on a single node so that we can +watch for files that consistently time out or wedged dataloader workers. +""" + +from __future__ import annotations + +import argparse +import os +import sys +import time +from pathlib import Path +from typing import Any, Dict, Iterable, List, Optional + +import torch +from torch.utils.data import DataLoader + +ROOT_DIR = Path(__file__).resolve().parents[2] +if str(ROOT_DIR) not in sys.path: + sys.path.insert(0, str(ROOT_DIR)) + +from training import data as video_data_module # noqa: E402 +from training.data import VideoCaptionDataset # noqa: E402 +from training.utils import image_transform as default_image_transform # noqa: E402 + + +def _resolve_llavavid_root(root_arg: Optional[str]) -> Path: + if root_arg: + root = Path(root_arg).expanduser().resolve() + else: + root = ROOT_DIR / "data" / "video" / "LLaVA-Video-178K" + if not root.exists(): + raise FileNotFoundError(f"LLaVA-Video root directory not found: {root}") + return root + + +def _identity_collate(batch: List[Optional[Dict[str, Any]]]) -> List[Dict[str, Any]]: + """Drop `None` samples that VideoCaptionDataset returns after repeated failures.""" + filtered = [sample for sample in batch if sample is not None] + return filtered + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Decode-check LLaVA-Video samples with the existing dataset logic." + ) + parser.add_argument( + "--llavavid-root", + type=str, + default=None, + help="Path to the LLaVA-Video-178K cache directory. Defaults to data/video/LLaVA-Video-178K relative to repo root.", + ) + parser.add_argument( + "--num-samples", + type=int, + default=256, + help=( + "Number of samples to attempt decoding (per DataLoader worker collectively). " + "Set to -1 to sweep the entire dataset once." + ), + ) + parser.add_argument( + "--batch-size", + type=int, + default=1, + help="Batch size for the diagnostic DataLoader.", + ) + parser.add_argument( + "--num-workers", + type=int, + default=4, + help="Number of DataLoader workers to spawn. Set to match your training run.", + ) + parser.add_argument( + "--num-frames", + type=int, + default=8, + help="Number of frames to request from load_video_mp4.", + ) + parser.add_argument( + "--resolution", + type=int, + default=256, + help="Resolution passed to the dataset transform.", + ) + parser.add_argument( + "--sample-method", + type=str, + default="uniform", + choices=("uniform", "random"), + help="Frame sampling strategy.", + ) + parser.add_argument( + "--report-every", + type=int, + default=10, + help="Print a progress line every N successfully decoded samples.", + ) + parser.add_argument( + "--timeout", + type=float, + default=30.0, + help="Maximum seconds to allow a batch to hang before treating it as a stall.", + ) + return parser.parse_args() + + +def _maybe_set_thread_limits() -> None: + # Avoid oversubscribing CPU threads when the loader uses multiple workers. + os.environ.setdefault("OMP_NUM_THREADS", "1") + os.environ.setdefault("MKL_NUM_THREADS", "1") + os.environ.setdefault("OPENBLAS_NUM_THREADS", "1") + os.environ.setdefault("NUMEXPR_NUM_THREADS", "1") + + +def main() -> None: + args = _parse_args() + _maybe_set_thread_limits() + + llavavid_root = _resolve_llavavid_root(args.llavavid_root) + print(f"[INFO] Using LLaVA-Video root: {llavavid_root}") + + original_loader = video_data_module.load_video_mp4 + + def traced_loader(*loader_args, **loader_kwargs): + video_path = loader_kwargs.get("video_path") + if video_path is None and loader_args: + video_path = loader_args[0] + start = time.time() + try: + frames = original_loader(*loader_args, **loader_kwargs) + except Exception as exc: # pylint: disable=broad-except + duration = time.time() - start + print(f"[ERROR] {video_path} raised {exc.__class__.__name__} after {duration:.2f}s: {exc}") + raise + duration = time.time() - start + status = "OK" if frames else "NONE" + print(f"[TRACE] {status:>4} | {duration:6.2f}s | {video_path}") + return frames + + video_data_module.load_video_mp4 = traced_loader + + try: + dataset = VideoCaptionDataset( + transform=default_image_transform, + tokenizer=None, + max_seq_length=256, + resolution=args.resolution, + dataset_name="llavavid", + llavavid_path=str(llavavid_root), + llavavid_local_files_only=True, + sample_method=args.sample_method, + num_frames=args.num_frames, + ) + + if len(dataset) == 0: + print("[ERROR] Dataset returned zero length. Check the root directory/config.") + sys.exit(1) + + dataloader = DataLoader( + dataset, + batch_size=args.batch_size, + shuffle=True, + num_workers=args.num_workers, + collate_fn=_identity_collate, + pin_memory=False, + drop_last=False, + ) + + print( + f"[INFO] Starting decode sweep: " + f"{args.num_samples} samples, batch_size={args.batch_size}, num_workers={args.num_workers}" + ) + + decoded = 0 + attempted = 0 + failed = 0 + start_time = time.time() + last_report = start_time + + for batch_idx, batch in enumerate(dataloader, start=1): + expected = args.batch_size + actual = len(batch) + + attempted += expected + failed += max(expected - actual, 0) + decoded += sum(1 for sample in batch if sample.get("video")) + + if args.num_samples > 0 and decoded >= args.num_samples: + break + + now = time.time() + if args.report_every > 0 and decoded and decoded % args.report_every == 0: + elapsed = now - last_report + total_elapsed = now - start_time + print( + f"[INFO] {decoded} successful samples " + f"(attempted={attempted}, failed={failed}) " + f"in {total_elapsed:.1f}s (+{elapsed:.1f}s since last report)." + ) + last_report = now + + if now - start_time > args.timeout: + print( + f"[WARN] Exceeded timeout of {args.timeout}s without reaching target samples." + ) + break + + total_elapsed = time.time() - start_time + print( + f"[RESULT] Completed sweep: decoded={decoded}, attempted={attempted}, " + f"failed={failed}, elapsed={total_elapsed:.1f}s." + ) + finally: + video_data_module.load_video_mp4 = original_loader + + +if __name__ == "__main__": + main() diff --git a/MMaDA/script/debug_speech_dataloader.py b/MMaDA/script/debug_speech_dataloader.py new file mode 100644 index 0000000000000000000000000000000000000000..a63d794994d516774baf7aee042613d88b56b78a --- /dev/null +++ b/MMaDA/script/debug_speech_dataloader.py @@ -0,0 +1,222 @@ +#!/usr/bin/env python3 +"""Utility to reproduce and debug the speech DataLoader used in training. + +This script pulls the speech dataset configuration from the Omada +instruction-tuning config, instantiates the same `MixedSpeechTextDataset`, and +iterates a configurable number of batches while measuring how long each fetch +takes. Use it to spot slow or stuck samples without launching the full training +job. + +Typical usage:: + + python AIDAS/MMaDA/script/debug_speech_dataloader.py \ + --config AIDAS/MMaDA/configs/omada_instruction_tuning.yaml \ + --flow s2t --max-batches 5 --num-workers 1 --timeout 0 + +Pass `--inspect-items` for a direct `dataset[idx]` sweep when a specific sample +seems suspicious. +""" + +from __future__ import annotations + +import argparse +import itertools +import logging +import sys +import time +from pathlib import Path +from typing import Any, Iterable, List + +from omegaconf import OmegaConf +from torch.utils.data import DataLoader + +from MMaDA.training.data import MixedSpeechTextDataset + + +def _collate_fn_audio(batch: List[dict[str, Any]]) -> dict[str, List[Any]]: + """Match the collate function used in training for speech flows.""" + + return { + "audio_path": [item["audio_path"] for item in batch], + "text": [item["text"] for item in batch], + "audio_tokens": [item.get("audio_tokens") for item in batch], + } + + +def _as_list_of_dicts(cfg_fragment: Any) -> List[dict[str, Any]]: + container = OmegaConf.to_container(cfg_fragment, resolve=True) + if not isinstance(container, Iterable): # pragma: no cover - sanity guard + raise TypeError("audio_data config must be a list of dataset dicts") + return list(container) # type: ignore[arg-type] + + +def _build_dataset(cfg) -> MixedSpeechTextDataset: + dataset_cfg = cfg.dataset.params + audio_data_cfg = _as_list_of_dicts(dataset_cfg.audio_data) + return MixedSpeechTextDataset(audio_data_cfg) + + +def _log_batch_summary(idx: int, batch: dict[str, List[Any]], elapsed: float) -> None: + audio_paths = batch.get("audio_path", []) + sample = audio_paths[0] if audio_paths else "" + logging.info( + "batch=%d size=%d elapsed=%.2fs sample=%s", + idx, + len(audio_paths), + elapsed, + sample, + ) + + +def _inspect_items(dataset: MixedSpeechTextDataset, max_items: int) -> None: + logging.info("Inspecting individual dataset items (max=%d)", max_items) + for idx in itertools.islice(range(len(dataset)), max_items): + tick = time.perf_counter() + try: + item = dataset[idx] + except Exception as exc: # pragma: no cover - diagnostic path + logging.error("idx=%d failed: %s", idx, exc) + continue + elapsed = time.perf_counter() - tick + logging.info( + "idx=%d elapsed=%.2fs path=%s text_len=%d tokens=%s", + idx, + elapsed, + item.get("audio_path"), + len(item.get("text", "")), + "cached" if item.get("audio_tokens") is not None else "None", + ) + + +def parse_args(argv: List[str]) -> argparse.Namespace: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--config", + type=Path, + default=Path("AIDAS/MMaDA/configs/omada_instruction_tuning.yaml"), + help="Path to the training config YAML", + ) + parser.add_argument( + "--flow", + choices=["s2t", "t2s"], + default="s2t", + help="Which speech flow's batch size defaults to use", + ) + parser.add_argument( + "--batch-size", + type=int, + default=None, + help="Override batch size (defaults to config.training.batch_size_)", + ) + parser.add_argument( + "--num-workers", + type=int, + default=None, + help="Override DataLoader workers (defaults to config.dataset.params.num_workers)", + ) + parser.add_argument( + "--persistent-workers", + action="store_true", + help="Enable persistent workers regardless of config", + ) + parser.add_argument( + "--timeout", + type=float, + default=None, + help="DataLoader timeout in seconds (defaults to config.dataset.params.dataloader_timeout)", + ) + parser.add_argument( + "--max-batches", + type=int, + default=10, + help="Number of batches to iterate (0 means run through the entire dataset)", + ) + parser.add_argument( + "--inspect-items", + type=int, + default=0, + help="If >0, bypass the DataLoader and inspect this many individual dataset items first", + ) + parser.add_argument( + "--prefetch-factor", + type=int, + default=None, + help="Optional override for DataLoader prefetch_factor", + ) + parser.add_argument( + "--log-level", + default="INFO", + help="Logging level", + ) + return parser.parse_args(argv) + + +def main(argv: List[str]) -> int: + args = parse_args(argv) + logging.basicConfig( + level=getattr(logging, args.log_level.upper(), logging.INFO), + format="%(asctime)s | %(levelname)s | %(message)s", + ) + + cfg = OmegaConf.load(args.config) + dataset = _build_dataset(cfg) + + if args.inspect_items: + _inspect_items(dataset, args.inspect_items) + + dataset_params = cfg.dataset.params + batch_size = args.batch_size or getattr(cfg.training, f"batch_size_{args.flow}") + num_workers = args.num_workers if args.num_workers is not None else dataset_params.num_workers + timeout = args.timeout if args.timeout is not None else dataset_params.dataloader_timeout + + if num_workers == 0: + persistent_workers = False + else: + persistent_workers = args.persistent_workers or bool(dataset_params.persistent_workers) + + dataloader_kwargs = { + "dataset": dataset, + "batch_size": batch_size, + "shuffle": False, + "num_workers": num_workers, + "drop_last": True, + "pin_memory": bool(dataset_params.pin_memory), + "timeout": timeout, + "persistent_workers": persistent_workers, + "collate_fn": _collate_fn_audio, + } + if args.prefetch_factor is not None and num_workers > 0: + dataloader_kwargs["prefetch_factor"] = args.prefetch_factor + + logging.info( + "Starting DataLoader debug: batch_size=%d num_workers=%d timeout=%s persistent=%s", + batch_size, + num_workers, + timeout, + persistent_workers, + ) + + dataloader = DataLoader(**dataloader_kwargs) + + max_batches = args.max_batches + iterator = iter(dataloader) + + processed = 0 + while True: + if max_batches and processed >= max_batches: + break + tick = time.perf_counter() + try: + batch = next(iterator) + except StopIteration: + logging.info("Reached end of DataLoader after %d batches", processed) + break + elapsed = time.perf_counter() - tick + _log_batch_summary(processed, batch, elapsed) + processed += 1 + + return 0 + + +if __name__ == "__main__": + raise SystemExit(main(sys.argv[1:])) diff --git a/MMaDA/script/emova.sh b/MMaDA/script/emova.sh new file mode 100644 index 0000000000000000000000000000000000000000..624eef4a1a7e6e401fb3bf418025681d8d7a63b4 --- /dev/null +++ b/MMaDA/script/emova.sh @@ -0,0 +1,6 @@ +#!/bin/bash + +export CUDA_VISIBLE_DEVICES=6,7 +export NUM_GPUS=2 + +torchrun --nproc_per_node=$NUM_GPUS --master_port=12351 /home/work/AIDAS/MMaDA/eval_emova.py \ No newline at end of file diff --git a/MMaDA/script/eval_i2i.sh b/MMaDA/script/eval_i2i.sh new file mode 100644 index 0000000000000000000000000000000000000000..9926eef07786fe434704a36c649c044770965dff --- /dev/null +++ b/MMaDA/script/eval_i2i.sh @@ -0,0 +1 @@ +python3 inference_i2i.py config=configs/mmada_demo.yaml guidance_scale=3.5 generation_timesteps=50 \ No newline at end of file diff --git a/MMaDA/script/eval_s2t.sh b/MMaDA/script/eval_s2t.sh new file mode 100644 index 0000000000000000000000000000000000000000..a2d5ee3a478132096cdade56fee2e5d2f6fd6373 --- /dev/null +++ b/MMaDA/script/eval_s2t.sh @@ -0,0 +1,9 @@ +#!/bin/bash + +torchrun --nproc_per_node=8 MMaDA/inference_s2t_emova.py \ + config=MMaDA/configs/mmada_demo_s2t.yaml \ + --train_step 145000\ + --remasking "low_confidence"\ + --generation_step 128\ + --new_tok 128\ + --block_length 64 \ No newline at end of file diff --git a/MMaDA/script/eval_s2t_merging.sh b/MMaDA/script/eval_s2t_merging.sh new file mode 100644 index 0000000000000000000000000000000000000000..282f2f607bdbfdfce07d239cfb77e272606fd205 --- /dev/null +++ b/MMaDA/script/eval_s2t_merging.sh @@ -0,0 +1,28 @@ +#!/bin/bash + +BASE_DIR="/home/work/AIDAS/ckpts/merged_model" + +merge_types=("average_merge_alpha" "hf_common_merge_alpha" "no_vocab_merge_alpha") + +alphas=(0.0 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0) + +for merge_type in "${merge_types[@]}"; do + for alpha in "${alphas[@]}"; do + CKPT_PATH="${BASE_DIR}/${merge_type}_${alpha}" + + echo "==========================================================" + echo "Running with ckpt_path=${CKPT_PATH}" + echo "==========================================================" + + torchrun --nproc_per_node=8 MMaDA/inference_s2t_emova.py \ + config=MMaDA/configs/mmada_demo_s2t.yaml \ + --ckpt_path "${CKPT_PATH}" \ + --train_step 0 \ + --remasking "low_confidence" \ + --generation_step 128 \ + --new_tok 128 \ + --block_length 64 + done +done + +echo "Grid search finished." \ No newline at end of file diff --git a/MMaDA/script/eval_t2s.sh b/MMaDA/script/eval_t2s.sh new file mode 100644 index 0000000000000000000000000000000000000000..a968dea764cb84e8bfc55f51c3034d1587ec0bfe --- /dev/null +++ b/MMaDA/script/eval_t2s.sh @@ -0,0 +1,3 @@ +# Should be done in AIDAS folder + +python3 MMaDA/inference_t2s.py config=MMaDA/configs/mmada_demo_speech.yaml batch_size=1 validation_prompts_file=MMaDA/validation_prompts/text2speech_prompts_tmp.txt guidance_scale=3.5 generation_timesteps=25 \ No newline at end of file diff --git a/MMaDA/script/eval_t2s_emova.sh b/MMaDA/script/eval_t2s_emova.sh new file mode 100755 index 0000000000000000000000000000000000000000..4e76b0fdada92459580c6d8cd155c74a648b4aec --- /dev/null +++ b/MMaDA/script/eval_t2s_emova.sh @@ -0,0 +1,8 @@ +#!/bin/bash + +torchrun --nproc_per_node=1 MMaDA/inference_t2s_emova.py \ + config=MMaDA/configs/mmada_demo_s2t.yaml \ + --train_step 135000\ + --guidance_scale 0.0\ + --timesteps 256\ + --speech_token_length 256 diff --git a/MMaDA/script/eval_t2s_grid.sh b/MMaDA/script/eval_t2s_grid.sh new file mode 100644 index 0000000000000000000000000000000000000000..7ccd88feda9e62e26fc9938ff214df12652ec734 --- /dev/null +++ b/MMaDA/script/eval_t2s_grid.sh @@ -0,0 +1,27 @@ +#!/bin/bash + +timesteps=(12) +speech_token_lengths=(150) +train_steps=($(seq 135000 -50000 35000)) + +for ts in "${timesteps[@]}" +do + for stl in "${speech_token_lengths[@]}" + do + for step in "${train_steps[@]}" + do + echo "==========================================================" + echo "Running with: train_step=${step}, timesteps=${ts}, speech_token_length=${stl}, guidance_scale=0.75" + echo "==========================================================" + + torchrun --nproc_per_node=8 MMaDA/inference_t2s_emova.py \ + config=MMaDA/configs/mmada_demo_s2t.yaml \ + --train_step ${step} \ + --guidance_scale 0.75 \ + --timesteps ${ts} \ + --speech_token_length ${stl} + done + done +done + +echo "Grid search finished." \ No newline at end of file diff --git a/MMaDA/script/eval_t2s_jake.sh b/MMaDA/script/eval_t2s_jake.sh new file mode 100644 index 0000000000000000000000000000000000000000..93556f7e85fcfc3d73a80960423b223233809238 --- /dev/null +++ b/MMaDA/script/eval_t2s_jake.sh @@ -0,0 +1,3 @@ +# Should be done in AIDAS folder + +python3 MMaDA/inference_t2s.py config=MMaDA/configs/mmada_demo_speech.yaml batch_size=1 validation_prompts_file=MMaDA/validation_prompts/text2speech_prompts.txt guidance_scale=3.5 generation_timesteps=25 \ No newline at end of file diff --git a/MMaDA/script/hostfile.txt b/MMaDA/script/hostfile.txt new file mode 100644 index 0000000000000000000000000000000000000000..10cb7dbfa35ef84202a7eb67103a11c4f159e68f --- /dev/null +++ b/MMaDA/script/hostfile.txt @@ -0,0 +1,4 @@ +main1 slots=8 +sub1 slots=8 +sub2 slots=8 +sub3 slots=8 \ No newline at end of file diff --git a/MMaDA/script/inference_mmu.sh b/MMaDA/script/inference_mmu.sh new file mode 100644 index 0000000000000000000000000000000000000000..914747c43804cdfca6f74f22cb9e024395beb9a2 --- /dev/null +++ b/MMaDA/script/inference_mmu.sh @@ -0,0 +1,7 @@ +export CUDA_VISIBLE_DEVICES=0,1 +echo "Using CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" + +python3 inference_mmu_ori.py \ + config=configs/mmada_demo_video.yaml \ + mmu_image_root=./mmu_validation \ + question='Please describe this image in detail.' \ No newline at end of file diff --git a/MMaDA/script/inference_video.sh b/MMaDA/script/inference_video.sh new file mode 100644 index 0000000000000000000000000000000000000000..1de86f2f41170d7c48b38cab54ba9c00c8d7e4e2 --- /dev/null +++ b/MMaDA/script/inference_video.sh @@ -0,0 +1,8 @@ +export CUDA_VISIBLE_DEVICES=0,1 +echo "Using CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" + +python3 inference_v2t.py \ + config=configs/mmada_demo_video.yaml \ + video_image_root=/home/work/AIDAS/video/demo \ + step=70000\ + question="Please provide a detailed description of the video." \ No newline at end of file diff --git a/MMaDA/script/train_i2i.sh b/MMaDA/script/train_i2i.sh new file mode 100644 index 0000000000000000000000000000000000000000..49edf59967cb3ee5e5f216804fb4d49671afd0fe --- /dev/null +++ b/MMaDA/script/train_i2i.sh @@ -0,0 +1,10 @@ +accelerate launch --config_file /home/work/AIDAS/MMaDA/accelerate_configs/3_node_8_gpus_deepspeed_zero1.yaml --machine_rank 2 --main_process_port=8888 /home/work/AIDAS/MMaDA/training/train_mmada_i2i.py config=/home/work/AIDAS/MMaDA/configs/mmada_pretraining_i2i.yaml + + +nohup accelerate launch \ + --config_file /home/work/AIDAS/MMaDA/accelerate_configs/3_node_8_gpus_deepspeed_zero1.yaml \ + --machine_rank 2 \ + --main_process_port=8888 \ + /home/work/AIDAS/MMaDA/training/train_mmada_i2i.py \ + config=/home/work/AIDAS/MMaDA/configs/mmada_pretraining_i2i.yaml \ + > train_mmada_i2i_sub2.log 2>&1 & \ No newline at end of file diff --git a/MMaDA/script/train_omada_instruction.sh b/MMaDA/script/train_omada_instruction.sh new file mode 100644 index 0000000000000000000000000000000000000000..0be46efd49f2a1d7145369459ca2fbede55a4a2b --- /dev/null +++ b/MMaDA/script/train_omada_instruction.sh @@ -0,0 +1,110 @@ +#!/usr/bin/env bash + +# Example manual launches retained for reference: +# accelerate launch --config_file /home/work/AIDAS/MMaDA/accelerate_configs/4_node_8_gpus_deepspeed_zero2_aidas.yaml --machine_rank 0 --main_process_port=8888 /home/work/AIDAS/MMaDA/training/train_omada_inst.py config=/home/work/AIDAS/MMaDA/configs/omada_instruction_tuning.yaml +# accelerate launch --config_file /home/work/AIDAS/MMaDA/accelerate_configs/4_node_8_gpus_deepspeed_zero2_aidas.yaml --machine_rank 1 --main_process_port=8888 /home/work/AIDAS/MMaDA/training/train_omada_inst.py config=/home/work/AIDAS/MMaDA/configs/omada_instruction_tuning.yaml +# accelerate launch --config_file /home/work/AIDAS/MMaDA/accelerate_configs/4_node_8_gpus_deepspeed_zero2_aidas.yaml --machine_rank 2 --main_process_port=8888 /home/work/AIDAS/MMaDA/training/train_omada_inst.py config=/home/work/AIDAS/MMaDA/configs/omada_instruction_tuning.yaml +# accelerate launch --config_file /home/work/AIDAS/MMaDA/accelerate_configs/4_node_8_gpus_deepspeed_zero2_aidas.yaml --machine_rank 3 --main_process_port=8888 /home/work/AIDAS/MMaDA/training/train_omada_inst.py config=/home/work/AIDAS/MMaDA/configs/omada_instruction_tuning.yaml + +export AIDAS_TRAIN_HOSTS="main1 sub1 sub2 sub3" + +set -euo pipefail + +PROJECT_ROOT="/home/work/AIDAS" +CONFIG_FILE="${PROJECT_ROOT}/MMaDA/accelerate_configs/4_node_8_gpus_deepspeed_zero2_aidas.yaml" +TRAIN_SCRIPT="${PROJECT_ROOT}/MMaDA/training/train_omada_inst.py" +EXPERIMENT_CFG="${PROJECT_ROOT}/MMaDA/configs/omada_instruction_tuning.yaml" +LOG_DIR="${PROJECT_ROOT}/logs" +MAIN_PORT="${MAIN_PORT:-8888}" +REMOTE_SETUP="${REMOTE_SETUP:-source ~/.bashrc && conda activate mmada}" +NCCL_DEBUG_LEVEL="${NCCL_DEBUG_LEVEL:-INFO}" + +if [[ -z "${AIDAS_TRAIN_HOSTS:-}" ]]; then + echo "Set AIDAS_TRAIN_HOSTS=\"host0 host1 host2 host3\" before running this script." >&2 + exit 1 +fi + +read -r -a HOSTS <<< "${AIDAS_TRAIN_HOSTS}" +NUM_MACHINES=${#HOSTS[@]} +if (( NUM_MACHINES == 0 )); then + echo "AIDAS_TRAIN_HOSTS is empty." >&2 + exit 1 +fi + +mkdir -p "$LOG_DIR" + +TIMESTAMP=$(date +%Y%m%d_%H%M%S) +declare -a PIDS=() +declare -a HOST_LABELS=() + +timestamp_lines() { + while IFS= read -r line; do + printf '%s %s\n' "$(date '+%Y-%m-%d %H:%M:%S')" "$line" + done +} + +stop_all() { + if (( ${#PIDS[@]} == 0 )); then + return + fi + echo "Stopping launched processes..." + for pid in "${PIDS[@]}"; do + if [[ -n "${pid:-}" ]] && kill -0 "$pid" 2>/dev/null; then + kill "$pid" >/dev/null 2>&1 || true + fi + done +} + +on_signal() { + echo "Signal received, terminating all ranks." + stop_all + exit 1 +} +trap on_signal INT TERM + +launch_rank() { + local host="$1" + local rank="$2" + local log_file="$3" + local host_label="$4" + + local base_cmd + if [[ -n "${REMOTE_SETUP}" ]]; then + base_cmd="${REMOTE_SETUP} && cd ${PROJECT_ROOT} && env NCCL_DEBUG=${NCCL_DEBUG_LEVEL} NCCL_SHM_DISABLE=1 NCCL_ASYNC_ERROR_HANDLING=1 accelerate launch --config_file ${CONFIG_FILE} --num_machines ${NUM_MACHINES} --machine_rank ${rank} --main_process_port ${MAIN_PORT} ${TRAIN_SCRIPT} config=${EXPERIMENT_CFG}" + else + base_cmd="cd ${PROJECT_ROOT} && env NCCL_DEBUG=${NCCL_DEBUG_LEVEL} NCCL_SHM_DISABLE=1 NCCL_ASYNC_ERROR_HANDLING=1 accelerate launch --config_file ${CONFIG_FILE} --num_machines ${NUM_MACHINES} --machine_rank ${rank} --main_process_port ${MAIN_PORT} ${TRAIN_SCRIPT} config=${EXPERIMENT_CFG}" + fi + local escaped_cmd + escaped_cmd=$(printf '%q' "$base_cmd") + + if [[ "$host" == "localhost" || "$host" == "$(hostname)" || "$host" == "$(hostname -f)" ]]; then + echo "[rank ${rank}] running locally (${host_label}), logging to ${log_file}" + stdbuf -oL -eL bash -lc "$base_cmd" 2>&1 | timestamp_lines >"$log_file" & + else + local dest="${SSH_USER:-$USER}@${host}" + echo "[rank ${rank}] ssh ${dest}, logging to ${log_file}" + ssh "$dest" "bash -lc $escaped_cmd" 2>&1 | timestamp_lines >"$log_file" & + fi + PIDS[$rank]=$! + HOST_LABELS[$rank]="$host_label" +} + +for idx in "${!HOSTS[@]}"; do + host="${HOSTS[$idx]}" + safe_host=${host//[^A-Za-z0-9_.-]/_} + log_file="${LOG_DIR}/train_inst_${TIMESTAMP}_rank${idx}_${safe_host}.log" + launch_rank "$host" "$idx" "$log_file" "$safe_host" +done + +echo "All nodes launched. Tail logs under ${LOG_DIR}." + +for rank in "${!PIDS[@]}"; do + pid="${PIDS[$rank]}" + [[ -n "${pid:-}" ]] || continue + if ! wait "$pid"; then + status=$? + echo "[rank ${rank}] (${HOST_LABELS[$rank]}) exited with status ${status}" + stop_all + exit $status + fi +done diff --git a/MMaDA/script/train_omada_stage1.sh b/MMaDA/script/train_omada_stage1.sh new file mode 100755 index 0000000000000000000000000000000000000000..274169f74cc0562d81009768d637ac1b958c7b4a --- /dev/null +++ b/MMaDA/script/train_omada_stage1.sh @@ -0,0 +1,29 @@ +accelerate launch --config_file /home/work/AIDAS/MMaDA/accelerate_configs/4_node_8_gpus_deepspeed_zero2_aidas.yaml --machine_rank 3 --main_process_port=8888 /home/work/AIDAS/MMaDA/training/train_omada_stage1.py config=/home/work/AIDAS/MMaDA/configs/omada_pretraining_stage1.yaml + +accelerate launch --config_file /home/work/AIDAS/MMaDA/accelerate_configs/1_node_8_gpus_deepspeed_zero4.yaml /home/work/AIDAS/MMaDA/training/train_omada_stage1.py config=/home/work/AIDAS/MMaDA/configs/omada_pretraining_stage1.yaml + +# 4 Nodes +export NCCL_SHM_DISABLE=1 && export NCCL_ASYNC_ERROR_HANDLING=1 && accelerate launch --config_file /home/work/AIDAS/MMaDA/accelerate_configs/4_node_8_gpus_deepspeed_zero2_aidas.yaml --machine_rank 0 --main_process_port=8888 /home/work/AIDAS/MMaDA/training/train_omada_stage1-3.py config=/home/work/AIDAS/MMaDA/configs/omada_pretraining_stage1-3.yaml +export NCCL_SHM_DISABLE=1 && export NCCL_ASYNC_ERROR_HANDLING=1 && accelerate launch --config_file /home/work/AIDAS/MMaDA/accelerate_configs/4_node_8_gpus_deepspeed_zero2_aidas.yaml --machine_rank 1 --main_process_port=8888 /home/work/AIDAS/MMaDA/training/train_omada_stage1-3.py config=/home/work/AIDAS/MMaDA/configs/omada_pretraining_stage1-3.yaml +export NCCL_SHM_DISABLE=1 && export NCCL_ASYNC_ERROR_HANDLING=1 && accelerate launch --config_file /home/work/AIDAS/MMaDA/accelerate_configs/4_node_8_gpus_deepspeed_zero2_aidas.yaml --machine_rank 2 --main_process_port=8888 /home/work/AIDAS/MMaDA/training/train_omada_stage1-3.py config=/home/work/AIDAS/MMaDA/configs/omada_pretraining_stage1-3.yaml +export NCCL_SHM_DISABLE=1 && export NCCL_ASYNC_ERROR_HANDLING=1 && accelerate launch --config_file /home/work/AIDAS/MMaDA/accelerate_configs/4_node_8_gpus_deepspeed_zero2_aidas.yaml --machine_rank 3 --main_process_port=8888 /home/work/AIDAS/MMaDA/training/train_omada_stage1-3.py config=/home/work/AIDAS/MMaDA/configs/omada_pretraining_stage1-3.yaml +# 2 Nodes +accelerate launch --config_file /home/work/AIDAS/MMaDA/accelerate_configs/2_node_8_gpus_deepspeed_zero2_aidas2.yaml --machine_rank 0 --main_process_port=8888 /home/work/AIDAS/MMaDA/training/train_omada_stage1.py config=/home/work/AIDAS/MMaDA/configs/omada_pretraining_stage1-2.yaml +accelerate launch --config_file /home/work/AIDAS/MMaDA/accelerate_configs/2_node_8_gpus_deepspeed_zero2_aidas2.yaml --machine_rank 1 --main_process_port=8888 /home/work/AIDAS/MMaDA/training/train_omada_stage1.py config=/home/work/AIDAS/MMaDA/configs/omada_pretraining_stage1-2.yaml +# Single Node +accelerate launch --config_file /home/work/AIDAS/MMaDA/accelerate_configs/1_node_8_gpus_deepspeed_zero4.yaml /home/work/AIDAS/MMaDA/training/train_omada_stage1-2.py config=/home/work/AIDAS/MMaDA/configs/omada_pretraining_stage1-2.yaml + + +LOG_DIR=/home/work/AIDAS/logs +mkdir -p "$LOG_DIR" +LOG_FILE="${LOG_DIR}/train_stage1_$(date +%Y%m%d_%H%M%S).log" + +export NCCL_SHM_DISABLE=1 +export NCCL_ASYNC_ERROR_HANDLING=1 +accelerate launch --config_file /home/work/AIDAS/MMaDA/accelerate_configs/4_node_8_gpus_deepspeed_zero2_aidas.yaml \ + --machine_rank 0 \ + --main_process_port=8888 \ + /home/work/AIDAS/MMaDA/training/train_omada_stage1-3.py \ + config=/home/work/AIDAS/MMaDA/configs/omada_pretraining_stage1-4.yaml \ + 2>&1 | tee "$LOG_FILE" + diff --git a/MMaDA/script/train_s2t.sh b/MMaDA/script/train_s2t.sh new file mode 100644 index 0000000000000000000000000000000000000000..2e04522d696143b92bad4597a8fa800214e43054 --- /dev/null +++ b/MMaDA/script/train_s2t.sh @@ -0,0 +1,3 @@ +# accelerate launch --config_file /home/work/AIDAS/MMaDA/accelerate_configs/3_node_8_gpus_deepspeed_zero1.yaml --machine_rank 0 --main_process_port=8888 /home/work/AIDAS/MMaDA/training/train_mmada_s2t.py config=/home/work/AIDAS/MMaDA/configs/mmada_pretraining_t2s.yaml + +accelerate launch --config_file /home/work/AIDAS/MMaDA/accelerate_configs/1_node_8_gpus_deepspeed_zero4.yaml /home/work/AIDAS/MMaDA/training/train_mmada_s2t.py config=/home/work/AIDAS/MMaDA/configs/mmada_pretraining_s2t.yaml diff --git a/MMaDA/script/train_t2s.sh b/MMaDA/script/train_t2s.sh new file mode 100644 index 0000000000000000000000000000000000000000..68e965ef22b45680c670ec8bf51f5603bdae9f2e --- /dev/null +++ b/MMaDA/script/train_t2s.sh @@ -0,0 +1 @@ +accelerate launch --config_file /home/work/AIDAS/MMaDA/accelerate_configs/1_node_8_gpus_deepspeed_zero4.yaml /home/work/AIDAS/MMaDA/training/train_mmada_t2s.py config=/home/work/AIDAS/MMaDA/configs/mmada_pretraining_t2s.yaml \ No newline at end of file diff --git a/MMaDA/script/train_test.sh b/MMaDA/script/train_test.sh new file mode 100755 index 0000000000000000000000000000000000000000..59baad45cf8598bb5c13ecab8755c35c40967a37 --- /dev/null +++ b/MMaDA/script/train_test.sh @@ -0,0 +1 @@ +accelerate launch --config_file /home/work/AIDAS/MMaDA/accelerate_configs/1_node_8_gpus_deepspeed_zero4.yaml /home/work/AIDAS/MMaDA/training/train_mmada_t2s_test.py config=/home/work/AIDAS/MMaDA/configs/mmada_pretraining_t2s.yaml \ No newline at end of file diff --git a/MMaDA/script/train_v2s.sh b/MMaDA/script/train_v2s.sh new file mode 100644 index 0000000000000000000000000000000000000000..8dde97094a4e5c0372d6c5303949fe49e8bddde8 --- /dev/null +++ b/MMaDA/script/train_v2s.sh @@ -0,0 +1,9 @@ +ONE_GPU_CONFIG=/home/work/AIDAS/MMaDA/accelerate_configs/1_gpu.yaml +ONE_NODE_CONFIG=/home/work/AIDAS/MMaDA/accelerate_configs/1_node_8_gpus_deepspeed_zero4.yaml +TWO_NODE_CONFIG=/home/work/AIDAS/MMaDA/accelerate_configs/2_node_8_gpus_deepspeed_zero4.yaml + +accelerate launch \ + --config_file ${ONE_NODE_CONFIG} \ + --main_process_port=22223 \ + /home/work/AIDAS/MMaDA/training/train_mmada_v2s.py \ + config=/home/work/AIDAS/MMaDA/configs/mmada_pretraining_v2s.yaml \ No newline at end of file diff --git a/MMaDA/script/train_v2t.sh b/MMaDA/script/train_v2t.sh new file mode 100644 index 0000000000000000000000000000000000000000..63e9ed334b66593436afb4e99e25e7df1db1ce89 --- /dev/null +++ b/MMaDA/script/train_v2t.sh @@ -0,0 +1,9 @@ +ONE_GPU_CONFIG=/home/work/AIDAS/MMaDA/accelerate_configs/1_gpu.yaml +ONE_NODE_CONFIG=/home/work/AIDAS/MMaDA/accelerate_configs/1_node_8_gpus_deepspeed_zero4.yaml +TWO_NODE_CONFIG=/home/work/AIDAS/MMaDA/accelerate_configs/2_node_8_gpus_deepspeed_zero4.yaml + +accelerate launch \ + --config_file ${ONE_NODE_CONFIG} \ + --main_process_port=22223 \ + /home/work/AIDAS/MMaDA/training/train_mmada_v2t.py \ + config=/home/work/AIDAS/MMaDA/configs/mmada_pretraining_v2t.yaml \ No newline at end of file diff --git a/MMaDA/script/train_v2t_inst.sh b/MMaDA/script/train_v2t_inst.sh new file mode 100644 index 0000000000000000000000000000000000000000..9b13e88804692e1d20eae15d30e64d51a8e4fa34 --- /dev/null +++ b/MMaDA/script/train_v2t_inst.sh @@ -0,0 +1,9 @@ +ONE_GPU_CONFIG=/home/work/AIDAS/MMaDA/accelerate_configs/1_gpu.yaml +ONE_NODE_CONFIG=/home/work/AIDAS/MMaDA/accelerate_configs/1_node_8_gpus_deepspeed_zero4.yaml +TWO_NODE_CONFIG=/home/work/AIDAS/MMaDA/accelerate_configs/2_node_8_gpus_deepspeed_zero4.yaml + +accelerate launch \ + --config_file ${ONE_NODE_CONFIG} \ + --main_process_port=22223 \ + /home/work/AIDAS/MMaDA/training/train_mmada_v2t_inst.py \ + config=/home/work/AIDAS/MMaDA/configs/omada_pretraining_v2t_inst.yaml \ No newline at end of file diff --git a/MMaDA/tools/run_dataloaders.py b/MMaDA/tools/run_dataloaders.py new file mode 100644 index 0000000000000000000000000000000000000000..2a1bbaf14ab094ce1ef255061292bc47735b7703 --- /dev/null +++ b/MMaDA/tools/run_dataloaders.py @@ -0,0 +1,242 @@ +#!/usr/bin/env python3 +""" +Utility script to sanity-check data loaders defined in train_omada_inst.py +without constructing the full training stack. + +Example: + python MMaDA/tools/run_dataloaders.py config=MMaDA/configs/omada_instruction_tuning.yaml \ + --flows v2t --num-workers 0 --max-batches 10 +""" + +from __future__ import annotations + +import argparse +import logging +import os +import sys +import time +from typing import Any, Dict, Iterable, List, Optional, Tuple + +import torch +from omegaconf import DictConfig, OmegaConf +from torch.utils.data import DataLoader +from transformers import AutoTokenizer + +# Ensure repository root is importable when executing from arbitrary cwd. +REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +if REPO_ROOT not in sys.path: + sys.path.insert(0, REPO_ROOT) + +from training.data import VideoCaptionDataset # noqa: E402 +from training.utils import image_transform # noqa: E402 + +LOGGER = logging.getLogger("run_dataloaders") + + +def _parse_args() -> Tuple[argparse.Namespace, DictConfig]: + parser = argparse.ArgumentParser(description="Run Omada dataloaders without the trainer.") + parser.add_argument( + "--flows", + default="v2t", + help="Comma separated list of dataloaders to exercise (currently supports: v2t). " + "Use 'all' to run every available flow.", + ) + parser.add_argument( + "--max-batches", + type=int, + default=0, + help="Stop after this many batches per loader (0 means iterate the entire epoch).", + ) + parser.add_argument( + "--num-workers", + type=int, + default=None, + help="Override DataLoader num_workers (falls back to config.dataset.params.num_workers).", + ) + parser.add_argument( + "--persistent-workers", + dest="persistent_workers", + action="store_true", + help="Force persistent_workers=True regardless of config.", + ) + parser.add_argument( + "--no-persistent-workers", + dest="persistent_workers", + action="store_false", + help="Force persistent_workers=False regardless of config.", + ) + parser.set_defaults(persistent_workers=None) + parser.add_argument( + "--seed", + type=int, + default=42, + help="Torch manual seed for reproducibility.", + ) + + args, unknown = parser.parse_known_args() + + cli_conf = OmegaConf.from_cli(unknown) + if "config" not in cli_conf: + parser.error("Please provide the training config via 'config=/path/to/config.yaml'.") + + yaml_conf = OmegaConf.load(cli_conf.config) + merged = OmegaConf.merge(yaml_conf, cli_conf) + + return args, merged + + +def _collate_v2t(batch: List[Dict[str, Any]]) -> Optional[Dict[str, Any]]: + """Minimal collate fn mirroring train_omada_inst.collate_fn_v2t.""" + filtered: List[Dict[str, Any]] = [sample for sample in batch if sample is not None] + if not filtered: + return None + + videos: List[torch.Tensor] = [] + captions: List[Any] = [] + for sample in filtered: + frames = sample.get("video") + caption = sample.get("caption") + if frames is None: + continue + try: + tensor = torch.stack(frames, dim=0) + except Exception as exc: + LOGGER.exception("Failed to stack frames for sample %s", sample) + raise exc + videos.append(tensor) + captions.append(caption) + + if not videos: + return None + + return { + "video": torch.stack(videos, dim=0), + "captions": captions, + } + + +def _build_v2t_loader( + cfg: DictConfig, + tokenizer, + *, + num_workers: int, + persistent_workers: bool, + pin_memory: bool, +) -> DataLoader: + speech_cfg = getattr(cfg.dataset.params, "video_speech_dataset", {}) + if not isinstance(speech_cfg, dict): + speech_cfg = OmegaConf.to_container(speech_cfg, resolve=True) + speech_cfg = speech_cfg or {} + + dataset = VideoCaptionDataset( + transform=image_transform, + tokenizer=tokenizer, + max_seq_length=int(cfg.dataset.preprocessing.max_seq_length), + resolution=int(cfg.dataset.preprocessing.resolution), + sample_method=speech_cfg.get("sample_method", "uniform"), + dataset_name=speech_cfg.get("llavavid_dataset_name", "llavavid"), + num_frames=int(speech_cfg.get("num_frames", 8)), + ) + + batch_size = int(max(1, cfg.training.batch_size_v2t)) + LOGGER.info( + "Instantiated VideoCaptionDataset with %d samples; batch_size=%d num_workers=%d", + len(dataset), + batch_size, + num_workers, + ) + + return DataLoader( + dataset, + batch_size=batch_size, + shuffle=True, + num_workers=num_workers, + pin_memory=pin_memory, + persistent_workers=persistent_workers if num_workers > 0 else False, + collate_fn=_collate_v2t, + drop_last=False, + ) + + +def _iterate_loader(name: str, loader: DataLoader, max_batches: int) -> None: + LOGGER.info("Starting iteration over '%s' (max_batches=%s)", name, max_batches or "full epoch") + start = time.time() + failures = 0 + processed = 0 + + try: + for step, batch in enumerate(loader, start=1): + if batch is None: + failures += 1 + LOGGER.warning("[%s] Received empty batch at step %d", name, step) + continue + + processed += batch["video"].size(0) + + if max_batches and step >= max_batches: + break + except Exception as exc: + LOGGER.exception("Loader '%s' raised an exception at batch %d", name, step) + raise exc + finally: + duration = time.time() - start + LOGGER.info( + "Finished '%s': steps=%d samples=%d failures=%d elapsed=%.2fs", + name, + step if 'step' in locals() else 0, + processed, + failures, + duration, + ) + + +def main() -> None: + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", + datefmt="%H:%M:%S", + ) + + args, cfg = _parse_args() + torch.manual_seed(args.seed) + + pin_memory = bool(getattr(cfg.dataset.params, "pin_memory", False)) + + if args.num_workers is None: + num_workers = int(getattr(cfg.dataset.params, "num_workers", 0)) + else: + num_workers = max(0, args.num_workers) + + if args.persistent_workers is None: + persistent_workers = bool(getattr(cfg.dataset.params, "persistent_workers", False)) + else: + persistent_workers = bool(args.persistent_workers) + + flows_arg = [item.strip().lower() for item in args.flows.split(",") if item.strip()] + if "all" in flows_arg: + flows = {"v2t"} + else: + flows = set(flows_arg) + + tokenizer = AutoTokenizer.from_pretrained(cfg.model.omada.tokenizer_path, padding_side="left") + + loaders: Dict[str, DataLoader] = {} + if "v2t" in flows: + loaders["v2t"] = _build_v2t_loader( + cfg, + tokenizer, + num_workers=num_workers, + persistent_workers=persistent_workers, + pin_memory=pin_memory, + ) + + if not loaders: + LOGGER.error("No loaders selected. Supported flows: v2t") + sys.exit(1) + + for name, loader in loaders.items(): + _iterate_loader(name, loader, args.max_batches) + + +if __name__ == "__main__": + main() diff --git a/MMaDA/training/__init__.py b/MMaDA/training/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bdb25cc918020c5db769289dfdec45476d7d073c --- /dev/null +++ b/MMaDA/training/__init__.py @@ -0,0 +1 @@ +# from .mmada_grpo_trainer import DiffusionGRPOTrainer \ No newline at end of file diff --git a/MMaDA/training/data.py b/MMaDA/training/data.py new file mode 100644 index 0000000000000000000000000000000000000000..48b4bdda0ed01f088c7f7775ade6e43d50ab8edd --- /dev/null +++ b/MMaDA/training/data.py @@ -0,0 +1,2805 @@ +# coding=utf-8 +# Copyright 2025 MMaDA Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import bisect +import csv +import logging +import itertools +import json +import math +import os +import hashlib +import contextlib +from pathlib import Path +from accelerate import Accelerator +from itertools import chain + +# Video real-time? +import os.path as osp +import time +import requests + +import random +import re +import datasets +import pandas as pd +from functools import partial +from typing import List, Optional, Union, Dict, Any, Sequence +from glob import glob +from tqdm import tqdm +import numpy as np +import cv2 +from PIL import Image +import torch + +from dataclasses import dataclass +from datasets import Dataset as HFDataset +from datasets import load_dataset, get_dataset_config_names +from io import BytesIO + +Image.warnings.simplefilter('error', Image.DecompressionBombWarning) + +import webdataset as wds +import yaml +from braceexpand import braceexpand +from torch.utils.data import default_collate, Dataset +from torchvision import transforms +from transformers import PreTrainedTokenizer +from datasets import ( + load_dataset, + load_from_disk, + DatasetDict, + DownloadConfig, + get_dataset_config_names, + concatenate_datasets, +) +import warnings +from training.utils import image_transform as utils_image_transform, image_transform_squash as utils_image_transform_squash +from webdataset.tariterators import ( + base_plus_ext, + tar_file_expander, + url_opener, + valid_sample, +) + +logger = logging.getLogger(__name__) + +S2T_INSTRUCTION = ["Transcribe the given audio.", + + "Write down what you hear in the audio.", + + "Provide a transcript for the given speech.", + + "What does the speaker in the audio say?", + + "Convert the speech in the audio to text.", + + "Listen to the audio and write out the text."] + +T2S_INSTRUCTION = ["Generate speech for the given text.", + + "Read the given sentence aloud.", + + "Say the given words.", + + "Convert the given text into spoken audio.", + + "Speak the given text.", + + "Synthesize the text into speech."] + +V2T_INSTRUCTION = ["Describe the video in detail.", + + "Please provide a detailed description of the video.", + + "What is happening in the video?", + + "Describe the content of the video in detail.",] + +V2S_INSTRUCTION = [ + "Generate speech that describes the given video.", + "Narrate the events happening in the video.", + "Produce spoken audio describing the video content.", + "Convert the video into a detailed spoken narration.", + "Speak a description of what is shown in the video.", + "Synthesize speech that explains the content of the video.", +] + +person_token = ["a person", "someone", "somebody"] + +def replace_person_token(t): + "Used for CC12M - handles all case variations of tag" + t = re.sub(r"([,\s]*(and)*[,\s]*)+", " people ", t, flags=re.IGNORECASE) + + person_pattern = re.compile(r"", re.IGNORECASE) + while person_pattern.search(t): + match = person_pattern.search(t) + t = t[:match.start()] + f" {random.choice(person_token)} " + t[match.end():] + + return t + + +def filter_keys(key_set): + def _f(dictionary): + return {k: v for k, v in dictionary.items() if k in key_set} + + return _f + + +def group_by_keys_nothrow(data, keys=base_plus_ext, lcase=True, suffixes=None, handler=None, src=None): + """Return function over iterator that groups key, value pairs into samples. + + :param keys: function that splits the key into key and extension (base_plus_ext) + :param lcase: convert suffixes to lower case (Default value = True) + """ + current_sample = None + for filesample in data: + assert isinstance(filesample, dict) + if "fname" not in filesample.keys(): + print(f"fname not in filesample.keys(): {filesample}") + print(f"src: {src}") + continue + fname, value = filesample["fname"], filesample["data"] + prefix, suffix = keys(fname) + if prefix is None: + continue + if lcase: + suffix = suffix.lower() + + if current_sample is None or prefix != current_sample["__key__"] or suffix in current_sample: + if valid_sample(current_sample): + yield current_sample + current_sample = dict(__key__=prefix, __url__=filesample["__url__"]) + if suffixes is None or suffix in suffixes: + current_sample[suffix] = value + if valid_sample(current_sample): + yield current_sample + + +def tarfile_to_samples_nothrow(src, handler=wds.warn_and_continue): + # NOTE this is a re-impl of the webdataset impl with group_by_keys that doesn't throw + + streams = url_opener(src, handler=handler) + files = tar_file_expander(streams, handler=handler) # [{fname,data,__url__}, ...] __url__ å­—ę®µę ‡čÆ†å½“å‰čÆ»å–ēš„ę–‡ä»¶ę„č‡Ŗå“ŖäøŖ tar 包 + samples = group_by_keys_nothrow(files, handler=handler, src=src) + return samples + + +def image_transform(sample, resolution=256): + image = sample["images"] + image = transforms.Resize(resolution, interpolation=transforms.InterpolationMode.BICUBIC)(image) + image = transforms.CenterCrop((resolution, resolution))(image) + image = transforms.ToTensor()(image) + image = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)(image) + sample["images"] = image + return sample + +def image_transform_squash(sample, resolution=256): + image = sample["images"] + image = transforms.Resize((resolution, resolution), interpolation=transforms.InterpolationMode.BICUBIC)(image) + image = transforms.ToTensor()(image) + image = transforms.Normalize(mean=[0.5, 0.5, 0.5],std=[0.5, 0.5, 0.5])(image) + sample["images"] = image + return sample + +def conditional_image_transform(sample, resolution=256): + url = sample.get("__url__", "") + special_datasets = ['ai2d', 'clevr', 'docvqa', 'geo'] + use_squash = False + for keyword in special_datasets: + if keyword in url: + use_squash = True + break + if use_squash: + return image_transform_squash(sample, resolution) + else: + return image_transform(sample, resolution) + + +def remove_prefix(caption): + caption = caption.replace('The image features ', '').replace('The image presents ', '').replace( + "The image you've sent is, ", '').replace("In the center of the image, ", '').replace( + "The image showcases ", '').replace("The image is ", '').replace( + "The image captures ", '').replace("In the given image ", '').replace( + "The image portrays ", '').replace("In the image, ", '').replace("In this image, we see ", '').replace( + "The image depicts ", '').replace("This is ", '').replace("In this image, ", '').replace( + "This image captures ", '') + + return caption + +def filter_long_samples(sample): + return sample.get('input_ids') is not None + + +class Text2ImageDataset: + def __init__( + self, + train_shards_path_or_url: Union[str, List[str]], + tokenizer: PreTrainedTokenizer, + max_seq_length: int, + num_train_examples: int, + per_gpu_batch_size: int, + global_batch_size: int, + num_workers: int, + resolution: int = 256, + shuffle_buffer_size: int = 1000, + pin_memory: bool = False, + persistent_workers: bool = False, + external_caption_path: Optional[str] = '', + external_journeydb_caption_path: Optional[str] = '', + external_laion12m_caption_path: Optional[str] = '', + external_cc12m_caption_path: Optional[str] = '', + external_text_to_image_2M_512_caption_path: Optional[str] = '', + external_ai2d_caption_path: Optional[str] = '', + external_clevr_caption_path: Optional[str] = '', + external_docvqa_caption_path: Optional[str] = '', + external_geo_caption_path: Optional[str] = '', + is_captioning: bool = False, + add_caption_prompt: bool = False, + long_caption: bool = True, + shuffle: bool = True, + ): + if f"{train_shards_path_or_url}.yaml" in os.listdir('./configs'): + with open(f"./configs/{train_shards_path_or_url}.yaml") as f: + train_shards_path_or_url = yaml.safe_load(f) + self.long_caption = long_caption + self.external_caption_path = external_caption_path + self.external_journeydb_caption_path = external_journeydb_caption_path + self.external_laion12m_caption_path = external_laion12m_caption_path + self.external_cc12m_caption_path = external_cc12m_caption_path + self.external_text_to_image_2M_512_caption_path = external_text_to_image_2M_512_caption_path + self.is_captioning = is_captioning + self.add_caption_prompt = add_caption_prompt + if self.add_caption_prompt: + with open("./training/questions.json") as f: + self.caption_prompt = json.load(f) + # self.caption_prompt = ['USER: \n' + prompt + ' ASSISTANT:' for prompt in self.caption_prompt] + self.caption_prompt = ['<|start_header_id|>user<|end_header_id|>\n' + prompt + '<|start_header_id|>assistant<|end_header_id|>\n' for prompt in self.caption_prompt] + else: + self.caption_prompt = None + + if external_journeydb_caption_path != '': + with open(external_journeydb_caption_path) as file: + self.journeydb_caption = json.load(file) + else: + self.journeydb_caption = None + + if external_ai2d_caption_path!= '': + self.ai2d_caption = pd.read_csv(external_ai2d_caption_path) + if external_clevr_caption_path!= '': + self.clevr_caption = pd.read_csv(external_clevr_caption_path) + if external_docvqa_caption_path!= '': + self.docvqa_caption = pd.read_csv(external_docvqa_caption_path) + if external_geo_caption_path!= '': + self.geo_caption = pd.read_csv(external_geo_caption_path) + + def tokenize(text): + if tokenizer is not None: + text = replace_person_token(text) + + encoding = tokenizer( + text, + truncation=True, + max_length=2 * max_seq_length, + padding=False, + return_tensors="pt" + ) + full_input_ids = encoding.input_ids[0] + + if len(full_input_ids) > max_seq_length: + return None + else: + return text + else: + return text + + + + if not isinstance(train_shards_path_or_url, str): + train_shards_path_or_url = [list(braceexpand(urls)) for urls in train_shards_path_or_url] + # flatten list using itertools + train_shards_path_or_url = list(itertools.chain.from_iterable(train_shards_path_or_url)) + + if external_caption_path != '': + processing_pipeline = [ + wds.decode("pil", handler=wds.ignore_and_continue), + wds.map(self.load_external_caption, handler=wds.ignore_and_continue), + wds.rename( + images="jpg;png;jpeg;webp", + input_ids="text;txt;caption", + handler=wds.warn_and_continue, + ), + wds.map(partial(conditional_image_transform, resolution=resolution), handler=wds.warn_and_continue), + wds.map(filter_keys(set(["images", "input_ids"]))), + wds.map_dict( + input_ids=tokenize, + handler=wds.warn_and_continue, + ), + wds.select(filter_long_samples), + ] + else: + processing_pipeline = [ + wds.decode("pil", handler=wds.ignore_and_continue), + wds.rename( + images="jpg;png;jpeg;webp", + input_ids="text;txt;caption", + handler=wds.warn_and_continue, + ), + wds.map(partial(conditional_image_transform, resolution=resolution), handler=wds.warn_and_continue), + wds.map(filter_keys(set(["images", "input_ids"]))), + wds.map_dict( + input_ids=tokenize, + handler=wds.warn_and_continue, + ), + wds.select(filter_long_samples), + ] + + pipeline = [ + wds.ResampledShards(train_shards_path_or_url), + tarfile_to_samples_nothrow, + wds.shuffle(shuffle_buffer_size), + *processing_pipeline, + wds.batched(per_gpu_batch_size, partial=False, collation_fn=default_collate), + ] + + num_batches = math.ceil(num_train_examples / global_batch_size) + num_worker_batches = math.ceil(num_train_examples / (global_batch_size * num_workers)) # per dataloader worker + num_batches = num_worker_batches * num_workers + num_samples = num_batches * global_batch_size + + self._train_dataset = wds.DataPipeline(*pipeline).with_epoch(num_worker_batches) + self._train_dataloader = wds.WebLoader( + self._train_dataset, + batch_size=None, + shuffle=False, + num_workers=num_workers, + pin_memory=pin_memory, + persistent_workers=persistent_workers, + ) + # add meta-data to dataloader instance for convenience + self._train_dataloader.num_batches = num_batches + self._train_dataloader.num_samples = num_samples + + def load_external_caption(self, sample): + + if 'SA1B' in sample['__key__'] or 'sa' in sample['__key__']: + captionf = f"{self.external_caption_path}/{sample['__key__'].split('/')[-1]}.txt" + if os.path.exists(captionf): + with open(captionf, "r") as reader: + captions = reader.readlines()[0].replace('\n', '') + else: + captions = "" + + # for captioning + if self.is_captioning: + if self.add_caption_prompt is not None: + prompt = random.sample(self.caption_prompt, 1)[0] + sample['txt'] = prompt + captions + else: + sample['txt'] = captions + # for generation + else: + # randomly choose short and long captions + if random.random() < 0.5: + sample['txt'] = captions.split('.')[0] + else: + sample['txt'] = captions + + sample['txt'] = remove_prefix(sample['txt']) + + return sample + + elif 'laion' in sample['__url__']: + url_part = sample['__url__'].split('/')[-1].split('.')[0] + key = sample['__key__'].split('/')[-1] + captionf = os.path.join(self.external_laion12m_caption_path, url_part, f"{key}.caption") + + if os.path.exists(captionf): + with open(captionf, "r") as reader: + captions = reader.read().strip() + else: + captions = "" + + # for captioning + if self.is_captioning: + if self.add_caption_prompt is not None: + prompt = random.sample(self.caption_prompt, 1)[0] + sample['txt'] = prompt + captions + else: + sample['txt'] = captions + # for generation + else: + # randomly choose short and long captions + if random.random() < 0.5: + sample['txt'] = captions.split('.')[0] + else: + sample['txt'] = captions + + sample['txt'] = remove_prefix(sample['txt']) + + return sample + + elif 'cc12m' in sample['__url__']: + url_part = sample['__url__'].split('/')[-1].split('.')[0] + key = sample['__key__'].split('/')[-1] + captionf = os.path.join(self.external_cc12m_caption_path, url_part, f"{key}.caption") + + if os.path.exists(captionf): + with open(captionf, "r") as reader: + captions = reader.read().strip() + else: + captions = "" + + # for captioning + if self.is_captioning: + if self.add_caption_prompt is not None: + prompt = random.sample(self.caption_prompt, 1)[0] + sample['txt'] = prompt + captions + else: + sample['txt'] = captions + # for generation + else: + # randomly choose short and long captions + if random.random() < 0.5: + sample['txt'] = captions.split('.')[0] + else: + sample['txt'] = captions + sample['txt'] = remove_prefix(sample['txt']) + + return sample + + elif "text-to-image-2M" in sample['__url__']: + if "json" in sample and "prompt" in sample["json"]: + captions = sample["json"]["prompt"] + else: + print(f"sample has no json or prompt: {sample}") + captions = "" + + + sample['txt'] = captions + + return sample + + elif 'ai2d' in sample['__url__']: + key = sample['__key__'].split('/')[-1] + df_row = self.ai2d_caption[self.ai2d_caption['image'].astype(str) == key + '.png'] + if len(df_row) == 0: + print(f"No captions available for key {sample['__key__']}") + return sample + elif len(df_row) > 1: + # print(f"Multiple captions available for key {sample['__key__']}") + df_row = df_row.sample(1) + question = df_row['question'].values[0] + solution = df_row['solution'].values[0] + caption = ( + '<|start_header_id|>user<|end_header_id|>\n' + "You should first think about the reasoning process in the mind and then provide the user with the answer. The reasoning process is enclosed within tags, i.e. reasoning process here answer here\n" + f"{question}\n" + '<|start_header_id|>assistant<|end_header_id|>\n' + f"{solution}" + ) + sample['txt'] = caption + return sample + + elif 'clevr' in sample['__url__']: + key = sample['__key__'].split('/')[-1] + df_row = self.clevr_caption[self.clevr_caption['image'].astype(str) == key + ".jpg"] + if len(df_row) == 0: + print(f"No captions available for key {sample['__key__']}") + return sample + elif len(df_row) > 1: + # print(f"Multiple captions available for key {sample['__key__']}") + df_row = df_row.sample(1) + question = df_row['question'].values[0] + solution = df_row['solution'].values[0] + caption = ( + '<|start_header_id|>user<|end_header_id|>\n' + "You should first think about the reasoning process in the mind and then provide the user with the answer. The reasoning process is enclosed within tags, i.e. reasoning process here answer here\n" + f"{question}\n" + '<|start_header_id|>assistant<|end_header_id|>\n' + f"{solution}" + ) + sample['txt'] = caption + return sample + + elif 'docvqa' in sample['__url__']: + key = sample['__key__'].split('/')[-1] + df_row = self.docvqa_caption[self.docvqa_caption['image'].astype(str) == key + ".png"] + if len(df_row) == 0: + print(f"No captions available for key {sample['__key__']}") + return sample + elif len(df_row) > 1: + # print(f"Multiple captions available for key {sample['__key__']}") + df_row = df_row.sample(1) + question = df_row['question'].values[0] + solution = df_row['solution'].values[0] + caption = ( + '<|start_header_id|>user<|end_header_id|>\n' + "You should first think about the reasoning process in the mind and then provide the user with the answer. The reasoning process is enclosed within tags, i.e. reasoning process here answer here\n" + f"{question}\n" + '<|start_header_id|>assistant<|end_header_id|>\n' + f"{solution}" + ) + sample['txt'] = caption + return sample + + elif 'geo' in sample['__url__']: + key = sample['__key__'].split('/')[-1] + df_row = self.geo_caption[self.geo_caption['image'].astype(str) == key + ".jpg"] + if len(df_row) == 0: + print(f"No captions available for key {sample['__key__']}") + return sample + elif len(df_row) > 1: + # print(f"Multiple captions available for key {sample['__key__']}") + df_row = df_row.sample(1) + question = df_row['question'].values[0] + solution = df_row['solution'].values[0] + caption = ( + '<|start_header_id|>user<|end_header_id|>\n' + "You should first think about the reasoning process in the mind and then provide the user with the answer. The reasoning process is enclosed within tags, i.e. reasoning process here answer here\n" + f"{question}\n" + '<|start_header_id|>assistant<|end_header_id|>\n' + f"{solution}" + ) + sample['txt'] = caption + return sample + + + elif self.journeydb_caption is not None and sample['__key__'] in self.journeydb_caption: + captions_list = self.journeydb_caption[sample['__key__']] + if len(captions_list) == 0: + print(f"No captions available for key {sample['__key__']}") + return sample + sample['txt'] = random.sample(captions_list, 1)[0] + return sample + + else: + print(f"none exist sample: {sample}") + return sample + + @property + def train_dataset(self): + return self._train_dataset + + @property + def train_dataloader(self): + return self._train_dataloader + +# +++++ S2T/T2S Dataset Definition +++++ +class SpeechTextDataset(Dataset): + def __init__(self, dataset : str, subset : str, split : Optional[str] = None): + self.dataset_name = dataset + + if self.dataset_name == "gigaspeech": # subset is either "xs" or "xl" + self.hgf_dataset : datasets.Dataset = load_dataset("speechcolab/gigaspeech", subset, split=split) + + + elif self.dataset_name == "librispeech": + root_path = "/home/work/AIDAS/data/audio/LibriSpeech" + self.dataset_path = root_path + "/" + subset # subset is like "train-clean-100", etc + if split is not None: + warnings.warn(f"Split parameter '{split}' is provided but will not be used for LibriSpeech dataset.") + + # librispeech path processing + self.subdirs_path = sorted(list(glob(self.dataset_path + "/*/*"))) + self.subdirs_len = [len(glob(subdir + "/*.flac")) for subdir in self.subdirs_path] + self.subdirs_len_accum = list(itertools.accumulate(self.subdirs_len)) + + # handle wrong subset name + if len(self.subdirs_path) == 0: + raise ValueError(f"Invalid subset name '{subset}' for LibriSpeech dataset. Available subsets are: train-clean-100, train-clean-360") + + + elif self.dataset_name == "commonvoice": + self.commonvoice_path = "/home/work/AIDAS/data/audio/commonvoice/cv-corpus-22.0-2025-06-20/en" + if split is not None: + warnings.warn(f"Split parameter '{split}' is provided but will not be used for commonvoice dataset.") + + self.tsv = pd.read_csv(self.commonvoice_path + f"/{subset}.tsv", sep="\t", usecols=["path", "sentence"]) + + else: + raise ValueError(f"Unsupported dataset: {dataset}. Supported datasets are: gigaspeech, librispeech, commonvoice.") + + def __len__(self): + if self.dataset_name == "gigaspeech": + return len(self.hgf_dataset) + elif self.dataset_name == "librispeech": + return self.subdirs_len_accum[-1] + else: # commonvoice + return len(self.tsv) + + def __getitem__(self, idx): + audio_path : str; text : str + + if self.dataset_name == "gigaspeech": + sample = self.hgf_dataset[idx] + audio_path = sample["audio"]["path"] + text = sample["text"] + + + elif self.dataset_name == "librispeech": + # idx overflow + if idx >= self.subdirs_len_accum[-1]: + raise IndexError(f"Index {idx} is out of bounds for the dataset with length {len(self)}.") + + # audio_path (flac) + subdir_idx = bisect.bisect_right(self.subdirs_len_accum, idx) + flac_idx = idx - self.subdirs_len_accum[subdir_idx - 1] if subdir_idx > 0 else idx + audio_path = sorted(list(glob(self.subdirs_path[subdir_idx]+"/*.flac")))[flac_idx] + + # text + txt_path = glob(self.subdirs_path[subdir_idx]+"/*.txt") + assert len(txt_path) == 1, f"Expected one txt file in {self.subdirs_path[subdir_idx]}, found {len(txt_path)}" + with open(txt_path[0], "r") as f: + txt = f.readlines() + text = " ".join(txt[flac_idx].split(" ")[1:]) # rip off the header, e.g., "103-1240-0007 [TEXT]" + + + else: # commonvoice + audio_path = self.commonvoice_path + "/clips/" + self.tsv.iloc[idx]["path"] + text = self.tsv.iloc[idx]["sentence"] + + return {"audio_path": audio_path, "text": text} + +class MixedSpeechTextDataset(Dataset): + def __init__(self, dataset_configs: list): + """ + Initializes and combines multiple speech datasets. + + Args: + dataset_configs (list): A list of configuration dictionaries, + where each dict defines a dataset to load. + """ + self.dataset_metadata = [] + self.dataset_lengths = [] + self._sha1 = hashlib.sha1 + + # Iterate through the list of dataset configurations from the YAML file + for config in dataset_configs: + name = config['name'] + subset = config.get('subset') + split = config.get('split') + use_tokens = bool(config.get("use_precomputed_tokens", False)) + token_root = config.get("precomputed_tokens_root") + token_root_path = Path(token_root).expanduser() if token_root else None + + print(f"Initializing dataset: {name} (Subset: {subset}, Split: {split})") + + # --- Gigaspeech --- + if name == "gigaspeech": + hgf_dataset = datasets.load_dataset("speechcolab/gigaspeech", subset, split=split) + self.dataset_metadata.append({ + "name": name, + "data": hgf_dataset, + "use_precomputed_tokens": use_tokens and token_root_path is not None, + "precomputed_tokens_root": token_root_path, + }) + self.dataset_lengths.append(len(hgf_dataset)) + + # --- LibriSpeech --- + elif name == "librispeech": + root_path = "/home/work/AIDAS/data/audio/LibriSpeech" + dataset_path = f"{root_path}/{subset}" + if split is not None: + warnings.warn(f"Split parameter '{split}' is provided but will not be used for LibriSpeech.") + + subdirs_path = sorted(glob(f"{dataset_path}/*/*")) + if not subdirs_path: + raise ValueError(f"Invalid subset for LibriSpeech or path not found: {dataset_path}") + + subdirs_len = [len(glob(f"{subdir}/*.flac")) for subdir in subdirs_path] + subdirs_len_accum = list(itertools.accumulate(subdirs_len)) + + metadata = { + "name": name, + "subdirs_path": subdirs_path, + "subdirs_len_accum": subdirs_len_accum, + "use_precomputed_tokens": use_tokens and token_root_path is not None, + "precomputed_tokens_root": token_root_path, + } + self.dataset_metadata.append(metadata) + self.dataset_lengths.append(subdirs_len_accum[-1]) + + # --- Common Voice --- + elif name == "commonvoice": + commonvoice_path = "/home/work/AIDAS/data/audio/commonvoice/cv-corpus-22.0-2025-06-20/en" + if split is not None: + warnings.warn(f"Split parameter '{split}' is provided but will not be used for Common Voice.") + + tsv_path = f"{commonvoice_path}/{subset}.tsv" + tsv = pd.read_csv(tsv_path, sep="\t", usecols=["path", "sentence"]) + + metadata = { + "name": name, + "data_root": f"{commonvoice_path}/clips/", + "tsv": tsv, + "use_precomputed_tokens": use_tokens and token_root_path is not None, + "precomputed_tokens_root": token_root_path, + } + self.dataset_metadata.append(metadata) + self.dataset_lengths.append(len(tsv)) + + else: + raise ValueError(f"Unsupported dataset: {name}.") + + # Calculate cumulative lengths to map a global index to a specific dataset + self.cumulative_lengths = list(itertools.accumulate(self.dataset_lengths)) + # print(f"āœ… All datasets loaded for the SPEECH!. Total length: {self.__len__()} samples.") + + def __len__(self): + """Returns the total number of samples across all datasets.""" + return self.cumulative_lengths[-1] if self.cumulative_lengths else 0 + + def __getitem__(self, idx): + """ + Fetches a sample from the combined dataset. + + It first determines which dataset the global index `idx` belongs to, + calculates the local index within that dataset, and then retrieves the item. + """ + if idx >= self.__len__(): + raise IndexError(f"Index {idx} is out of bounds for the combined dataset with length {self.__len__()}.") + + # Find which dataset the index belongs to + dataset_idx = bisect.bisect_right(self.cumulative_lengths, idx) + + # Calculate the local index within that dataset + local_idx = idx - self.cumulative_lengths[dataset_idx - 1] if dataset_idx > 0 else idx + + metadata = self.dataset_metadata[dataset_idx] + dataset_name = metadata["name"] + dataset_length = self.dataset_lengths[dataset_idx] + + audio_path: str + text: str + audio_tokens: Optional[torch.Tensor] + + max_retry = 5 + retry = 0 + + while retry < max_retry: + try: + audio_tokens = None + + if dataset_name == "gigaspeech": + sample = metadata["data"][local_idx] + audio_path = sample["audio"]["path"] + text = sample["text"] + # Preprocess special words to punctuation + text = ( + text.replace(" ", ",") + .replace(" ", ".") + .replace(" ", "?") + .replace(" ", "!") + ) + + elif dataset_name == "librispeech": + # Find the specific subdirectory and file using the local index + subdir_idx = bisect.bisect_right(metadata["subdirs_len_accum"], local_idx) + flac_idx = local_idx - metadata["subdirs_len_accum"][subdir_idx - 1] if subdir_idx > 0 else local_idx + + subdir_path = metadata["subdirs_path"][subdir_idx] + audio_path = sorted(glob(f"{subdir_path}/*.flac"))[flac_idx] + + # Read the corresponding transcript + txt_path = glob(f"{subdir_path}/*.txt")[0] + with open(txt_path, "r") as f: + line = f.readlines()[flac_idx] + text = " ".join(line.strip().split(" ")[1:]) + + else: # commonvoice + row = metadata["tsv"].iloc[local_idx] + audio_path = metadata["data_root"] + row["path"] + text = row["sentence"] + # Preprocess lower case to upper case + text = text.upper() + + audio_tokens = self._maybe_load_precomputed_tokens(audio_path, metadata) + return { + "audio_path": audio_path, + "text": text, + "audio_tokens": audio_tokens, + } + + except Exception as exc: + print(f"[MixedSpeechTextDataset] Failed to load sample from '{dataset_name}' at local index {local_idx}: {exc!r}") + retry += 1 + if retry >= max_retry: + break + local_idx = random.randint(0, dataset_length - 1) + continue + + raise RuntimeError(f"Unable to fetch a valid sample from dataset '{dataset_name}' after {max_retry} retries.") + + def _maybe_load_precomputed_tokens(self, audio_path: str, metadata: dict) -> Optional[torch.Tensor]: + if not metadata.get("use_precomputed_tokens"): + return None + root: Optional[Path] = metadata.get("precomputed_tokens_root") + if root is None: + return None + if not root.exists(): + logger.warning("Precomputed token root missing: %s", root) + return None + key = os.path.abspath(audio_path) + digest = self._sha1(key.encode("utf-8")).hexdigest() + token_path = root / digest[:2] / digest[2:4] / f"{digest}.pt" + if not token_path.exists(): + logger.warning("Precomputed audio tokens not found: %s", token_path) + return None + try: + tokens = torch.load(token_path, map_location="cpu") + if isinstance(tokens, torch.Tensor): + return tokens.clone() + if isinstance(tokens, (list, tuple)): + return torch.tensor(tokens, dtype=torch.long) + logger.warning("Unexpected token format in %s (type=%s)", token_path, type(tokens)) + except Exception as exc: + logger.warning("Failed to load precomputed tokens %s: %s", token_path, exc) + return None + + +class Speech2SpeechDataset(Dataset): + """ + Mixed dataset of emova-sft and InstructS2S-200K. + Return value of __getitem__ indicates a pair of (user, assistant) message (single-turn). + Critically, the return type for emova_sft and instructs2s are different: + emova_sft: tuple[list[int], list[int], Any] + instructs2s: tuple[str, str, Optional[Any]] + So in the main training code, we need to handle both types of return values. + + Notes: + - Use `s2s_collate_fn` within DataLoader. + - For emova_sft, Tensor cannot be returned because padding is done later in the main training code. + + For reference: + Total samples: 496514 + emova-sft (speech-text): 73.6k + emova-sft (speech-image): 71.5k + InstructS2S-200K: 422856 + """ + def __init__(self, dataset_configs: list): + self.dataset_configs = dataset_configs # currently this arg is not used + + ## emova-sft (text + image splits) + emova_sft_text = load_dataset("Emova-ollm/emova-sft-4m", "emova-speech-text-en", split='train') + emova_sft_image = load_dataset("Emova-ollm/emova-sft-4m", "emova-speech-image-en", split='train') + + def _maybe_cast_image_columns(ds): + for column in ("image", "images"): + if column in ds.column_names: + try: + ds = ds.cast_column(column, datasets.Image(decode=False)) + except Exception: + # Column may already be raw bytes/str; ignore and keep as-is + pass + return ds + + emova_sft_text = _maybe_cast_image_columns(emova_sft_text) + emova_sft_image = _maybe_cast_image_columns(emova_sft_image) + + def emova_sft_preprocess_batch(batch): + # Extract conversations from the batch + conversations_list = batch['conversations'] + + usr_ids_list = [] + asst_ids_list = [] + images_list = [] + + def normalize_emova_ids(ids: str) -> list[int]: + unit_numbers = ids.replace('<|speech_', '').replace('|>', ' ').strip() + unit_ids = [int(unit) for unit in unit_numbers.split(" ")] + return unit_ids + + # Process each conversation in the batch + for conversations in conversations_list: + usr_raw = conversations[0]['value'] + asst_raw = conversations[1]['value'] + + usr_ids: str = usr_raw.split("\n\nuser question speech:")[-1].strip() + asst_ids: str = json.loads(asst_raw)['assistant response speech'].strip() + + usr_ids_list.append(normalize_emova_ids(usr_ids)) + asst_ids_list.append(normalize_emova_ids(asst_ids)) + + raw_images = ( + batch.get('image') + or batch.get('images') + or batch.get('image_base64') + or [None] * len(conversations_list) + ) + + if not isinstance(raw_images, (list, tuple)): + raw_images = [raw_images] * len(conversations_list) + else: + raw_images = list(raw_images) + + if len(raw_images) != len(conversations_list): + # Align lengths by padding/truncating with None without decoding payloads + adjusted = raw_images[:len(conversations_list)] + if len(adjusted) < len(conversations_list): + adjusted.extend([None] * (len(conversations_list) - len(adjusted))) + raw_images = adjusted + + images_list.extend(raw_images) + + # Return a dictionary with lists of processed data + return { + "usr_ids": usr_ids_list, + "asst_ids": asst_ids_list, + "image": images_list, + } + + self.emova_sft_text = emova_sft_text.map( + emova_sft_preprocess_batch, + batched=True, + batch_size=1024, + remove_columns=['conversations'], + desc="Processing emova-sft (text)", + num_proc=16 + ) + + self.emova_sft_image = emova_sft_image.map( + emova_sft_preprocess_batch, + batched=True, + batch_size=1024, + remove_columns=['conversations'], + desc="Processing emova-sft (image)", + num_proc=16 + ) + + self._emova_text_len = len(self.emova_sft_text) + self._emova_image_len = len(self.emova_sft_image) + self._emova_total_len = self._emova_text_len + self._emova_image_len + + ## InstructS2S-200K (with caching) + instructs2s_rootdir = "/home/work/AIDAS/data/InstructS2S-200K/en/wav" + self.instructs2s_wav_pair_paths = [] + + pairs_txt = os.path.join(instructs2s_rootdir, "pairs.txt") + if os.path.isfile(pairs_txt): + with open(pairs_txt, "r") as f: + for line in tqdm(f, desc="Loading InstructS2S-200K paths from cached file"): + line = line.strip() + if not line: + continue + parts = line.split() + if len(parts) >= 2: + self.instructs2s_wav_pair_paths.append((parts[0], parts[1])) + else: + instructs2s_wav_dirs = [p for p in glob(os.path.join(instructs2s_rootdir, "*")) if os.path.isdir(p)] + # Walk each directory and collect (user, assistant) wav pairs + for dir_path in tqdm(instructs2s_wav_dirs, desc="Processing instructs2s-200k"): + dir_name = os.path.basename(dir_path) + k = 1 + while True: + user_wav = os.path.join(dir_path, f"{dir_name}-{k}-user.wav") + assistant_wav = os.path.join(dir_path, f"{dir_name}-{k}-assistant.wav") + if os.path.isfile(user_wav) and os.path.isfile(assistant_wav): + self.instructs2s_wav_pair_paths.append((user_wav, assistant_wav)) + k += 1 + continue + break + with open(pairs_txt, "w") as f: + for u, a in self.instructs2s_wav_pair_paths: + f.write(f"{u} {a}\n") + + ## Mixed dataset (ordered) + self.mixed_dataset = [self.emova_sft_text, self.emova_sft_image, self.instructs2s_wav_pair_paths] + + def __len__(self): + return sum([len(dataset) for dataset in self.mixed_dataset]) + + def __getitem__(self, idx) -> Union[tuple[list[int], list[int], Any], tuple[str, str, Optional[Any]]]: + if idx < self._emova_text_len: # emova_sft text split + sample = self.emova_sft_text[idx] + elif idx < self._emova_total_len: # emova_sft image split + sample = self.emova_sft_image[idx - self._emova_text_len] + else: # instructs2s + local_idx = idx - self._emova_total_len + usr_wav, asst_wav = self.instructs2s_wav_pair_paths[local_idx] + return usr_wav, asst_wav, None # tuple[str, str, Optional[image]]; wav file paths + + usr_ids = sample['usr_ids'] + asst_ids = sample['asst_ids'] + image = sample.get('image') + return usr_ids, asst_ids, image # tuple[list[int], list[int], image] + +def s2s_collate_fn(batch): + """ + Collate function for Speech2SpeechDataset. + """ + emova_data = [] + instructs2s_data = [] + + for item in batch: + if isinstance(item[0], list): # emova_sft: tuple[list[int], list[int]] + emova_data.append(item) + else: # instructs2s: tuple[str, str] + instructs2s_data.append(item) + + return { + 'emova_sft': emova_data, + 'instructs2s': instructs2s_data, + } + + +class VideoCaptionDataset(Dataset): + def __init__( + self, + transform, + tokenizer, + max_seq_length: int, + + resolution: int = 256, + panda70m_path = "/home/work/AIDAS/data/video/panda70m/panda70m_training_2m", + openvid1m_path = "/home/work/AIDAS/data/video/openvid1m/video", + webvid10m_path = "/home/work/AIDAS/data/video/webvid10m", + llavavid_path = "/home/work/AIDAS/data/video/LLaVA-Video-178K", + + dataset_name = "openvid1m", + llavavid_local_files_only: bool = False, + llavavid_skip_configs: Optional[Sequence[str]] = None, + llavavid_skip_video_patterns: Optional[Sequence[str]] = None, + + sample_method='uniform', + num_frames: int = 8, + vq_model=None, + ): + + available_datasets = ['panda70m', 'openvid1m', 'webvid10m', 'llavavid'] + if dataset_name not in available_datasets: + raise ValueError(f"Invalid dataset name: {dataset_name}. Available datasets: {available_datasets}") + + self.max_seq_length = max_seq_length + self.transform = transform + self.vq_model = vq_model + self.tokenizer = tokenizer + self.resolution = resolution + self.sample_method = sample_method + self.dataset_name = dataset_name + self.num_frames = num_frames + self.llavavid_local_files_only = llavavid_local_files_only + self.llavavid_skip_configs = set(llavavid_skip_configs or []) + self.llavavid_skip_video_patterns = tuple(llavavid_skip_video_patterns or []) + self.caption_prompt = V2T_INSTRUCTION + self.caption_prompt = ['<|start_header_id|>user<|end_header_id|>\n' + prompt + '<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n' for prompt in self.caption_prompt] + + self.webvid10m_path = webvid10m_path + + if dataset_name == 'panda70m': + self.vid_data = self._collect_panda70m(panda70m_path) + self.dataset_root = panda70m_path + elif dataset_name == 'webvid10m': + self.vid_data = self._collect_webvid10m(webvid10m_path) + self.dataset_root = webvid10m_path + elif dataset_name == 'openvid1m': + self.vid_data = self._collect_openvid1m(openvid1m_path) + self.dataset_root = openvid1m_path + elif dataset_name == 'llavavid': + self.vid_data = self._collect_llavavid(llavavid_path) + self.dataset_root = Path(llavavid_path) + self.llavavid_video_root = Path(llavavid_path) + + else: + raise ValueError(f"Invalid dataset name: {dataset_name}. Available datasets: panda70m, webvid10m") + + def _get_caption_prompt(self): + """ + Get a random caption prompt from the list of caption prompts. + """ + return np.random.choice(self.caption_prompt) + + def _tokenize(self, text): + if self.tokenizer is not None: + input_ids = self.tokenizer( + text, + truncation=True, + max_length=2 * self.max_seq_length, + padding=False, + return_tensors="pt" + )[0] + + if len(input_ids) > self.max_seq_length: + return None + else: + return input_ids + else: + raise ValueError("Tokenizer is not provided.") + + def _collect_webvid10m(self, root_path): + + print("Loading videos from WebVid10m dataset...") + csv_path = osp.join(root_path, "webvid-10M-train.csv") + + webvid_pd = pd.read_csv(csv_path) + self.dataset_length = len(webvid_pd) + print(f"{len(webvid_pd)} videos has been loaded.") + + return webvid_pd + + def _collect_panda70m(self, root_path): + video_caption_pairs = [] + subdirs = sorted(os.listdir(root_path)) + + print("Loading videos from panda70m dataset...") + for subdir in subdirs: + full_subdir = os.path.join(root_path, subdir) + if not os.path.isdir(full_subdir): + continue + + video_paths = glob(os.path.join(full_subdir, "*.mp4")) + for video_path in video_paths: + caption_path = video_path.replace(".mp4", ".txt") + if os.path.exists(caption_path): + with open(caption_path, 'r') as f: + caption = f.read().strip() + prompt = self._get_caption_prompt() + video_caption_pairs.append({ + "video": video_path, + "caption": prompt + caption + }) + print(f"{len(video_caption_pairs)} videos has been loaded.") + + return video_caption_pairs + + def _collect_openvid1m(self, root_path): + csv_path = osp.join(root_path, "OpenVid-1M.csv") + openvid_pd = pd.read_csv(csv_path) + self.dataset_length = len(openvid_pd) + print(f"{len(openvid_pd)} videos has been loaded.") + + return openvid_pd + + + def _collect_llavavid( + self, + root_path="lmms-lab/LLaVA-Video-178K", + cache_dir="/home/work/AIDAS/huggingface/datasets" + ): + """ + Collect all available (and locally cached) subsets of the LLaVA-Video-178K dataset. + Handles both on-disk exports (each config stored as subfolders of splits) and remote configs. + Returns a single flattened HuggingFace Dataset that concatenates every successfully loaded config. + """ + DATASET_NAME = root_path + + local_root = Path(DATASET_NAME) + configs: list[str] + using_local_dirs = local_root.exists() + + configs = [] + if using_local_dirs: + for p in sorted(local_root.iterdir()): + if not p.is_dir(): + continue + if p.name.startswith("."): + continue + split_exists = any((p / split_name).exists() for split_name in ("open_ended", "caption", "multi_choice")) + if not split_exists: + continue + configs.append(p.name) + if not configs: + using_local_dirs = False + + if not configs: + try: + configs = get_dataset_config_names(DATASET_NAME) + using_local_dirs = False + except Exception as e: + raise RuntimeError(f"Failed to fetch configs for {DATASET_NAME}: {e}") + + skip_configs = getattr(self, "llavavid_skip_configs", set()) + if skip_configs: + existing = [cfg for cfg in configs if cfg in skip_configs] + if existing: + print(f"LLaVA-Vid: skipping configs {existing}") + configs = [cfg for cfg in configs if cfg not in skip_configs] + + if not configs: + raise RuntimeError("All LLaVA-Video configs were skipped; nothing left to load.") + + def _add_config_column(dataset: HFDataset, cfg_name: str, row_count: int): + """Attach the originating config name so downstream can locate videos.""" + if dataset is None or not cfg_name: + return dataset + if "llavavid_config" in dataset.column_names: + return dataset + return dataset.add_column("llavavid_config", [cfg_name] * row_count) + + def _flatten_dataset(ds_obj, label: str, cfg_name: str): + """Convert DatasetDicts into a single Dataset and report the row count.""" + if ds_obj is None: + return None, 0 + if isinstance(ds_obj, DatasetDict): + splits = [split for split in ds_obj.values()] + if not splits: + print(f"Skipping {label}: dataset dict has no splits.") + return None, 0 + total_rows = sum(len(split) for split in splits) + if len(splits) == 1: + return splits[0], total_rows + try: + merged = concatenate_datasets(splits) + except Exception as merge_err: + print(f"Skipping {label}: failed to concatenate splits: {merge_err}") + return None, 0 + dataset = merged + else: + dataset = ds_obj + try: + total_rows = len(dataset) + except Exception as len_err: + print(f"Skipping {label}: unable to compute dataset length ({len_err}).") + return None, 0 + + dataset = _add_config_column(dataset, cfg_name, total_rows) + return dataset, total_rows + + def _load_local_config(cfg_name: str): + """Attempt to read a single config from disk, handling split sub-directories if needed.""" + cfg_root = local_root / cfg_name + if not cfg_root.exists(): + return None, 0 + + # First try loading the directory directly (Dataset or DatasetDict exports). + try: + ds_direct = load_from_disk(str(cfg_root)) + except Exception as direct_err: + print(f"Failed to load config {cfg_name} via load_from_disk: {direct_err}.") + else: + ds_flat, ds_count = _flatten_dataset(ds_direct, cfg_name, cfg_name) + if ds_flat is not None and ds_count > 0: + return ds_flat, ds_count + + # Fallback: iterate over split sub-directories (caption/open_ended/multi_choice, etc.). + split_dirs = [p for p in sorted(cfg_root.iterdir()) if p.is_dir()] + if not split_dirs: + return None, 0 + + split_datasets = [] + for split_dir in split_dirs: + try: + split_ds = load_from_disk(str(split_dir)) + except Exception as split_err: + print(f"Skipping {cfg_name}/{split_dir.name}: {split_err}") + continue + split_datasets.append(split_ds) + + if not split_datasets: + return None, 0 + + split_total = sum(len(split_ds) for split_ds in split_datasets) + if len(split_datasets) == 1: + dataset = split_datasets[0] + else: + try: + dataset = concatenate_datasets(split_datasets) + except Exception as merge_err: + print(f"Skipping {cfg_name}: failed to concatenate split datasets: {merge_err}") + return None, 0 + + dataset = _add_config_column(dataset, cfg_name, split_total) + return dataset, split_total + + datasets_loaded = [] + total_count = 0 + + for cfg in configs: + ds = None + cfg_count = 0 + + if using_local_dirs: + ds, cfg_count = _load_local_config(cfg) + + if ds is None or cfg_count == 0: + download_cfg = None + if self.llavavid_local_files_only: + download_cfg = DownloadConfig(local_files_only=True) + try: + remote_ds = load_dataset( + DATASET_NAME, + name=cfg, + cache_dir=cache_dir, + verification_mode="no_checks", + download_config=download_cfg, + ) + except Exception as remote_err: + print(f"Skipping {cfg}: {remote_err}") + continue + + ds, cfg_count = _flatten_dataset(remote_ds, cfg, cfg) + if ds is None or cfg_count == 0: + print(f"Skipping {cfg}: dataset empty after flattening.") + continue + + datasets_loaded.append(ds) + total_count += cfg_count + + if not datasets_loaded: + raise RuntimeError("No valid configs could be loaded!") + + if len(datasets_loaded) == 1: + global_dataset = datasets_loaded[0] + else: + try: + global_dataset = concatenate_datasets(datasets_loaded) + except Exception as merge_err: + print(f"Failed to concatenate configs in one step: {merge_err}. Trying pairwise concatenation.") + try: + combined = datasets_loaded[0] + for ds_next in datasets_loaded[1:]: + combined = concatenate_datasets([combined, ds_next]) + global_dataset = combined + except Exception as pair_err: + raise RuntimeError(f"Unable to merge LLaVA-Video configs: {pair_err}") from pair_err + + # Filter out samples whose video path matches known-bad patterns (e.g., missing shareVideoGPTV frames) + skip_patterns = getattr(self, "llavavid_skip_video_patterns", tuple()) + if skip_patterns: + def _matches_skip(entry: dict[str, Any]) -> bool: + video_entry = entry.get("video") + if not isinstance(video_entry, str): + return False + return any(pattern in video_entry for pattern in skip_patterns) + + def _filter_dataset(ds_obj): + if isinstance(ds_obj, list): + filtered_list = [] + removed_total = 0 + for item in ds_obj: + filtered_item, removed_item = _filter_dataset(item) + removed_total += removed_item + if filtered_item is None: + continue + filtered_list.append(filtered_item) + return filtered_list, removed_total + elif isinstance(ds_obj, HFDataset): + before = len(ds_obj) + filtered = ds_obj.filter(lambda ex: not _matches_skip(ex)) + removed = before - len(filtered) + return filtered, removed + elif isinstance(ds_obj, dict): + return (None, 1) if _matches_skip(ds_obj) else (ds_obj, 0) + else: + return ds_obj, 0 + + global_dataset, removed_samples = _filter_dataset(global_dataset) + if removed_samples > 0: + total_count -= removed_samples + print(f"LLaVA-Vid: skipped {removed_samples} samples matching patterns {skip_patterns}.") + + print(f"LLaVA-Vid: {len(datasets_loaded)} configs loaded.") + print(f"LLaVA-Vid: {total_count:,} total samples loaded.") + + self.dataset_length = total_count + return global_dataset + + def __len__(self): + return len(self.vid_data) + + def __getitem__(self, idx): + max_try_count = 50 + + for try_count in range(max_try_count): + try: + data = self._sample_data(idx) + except Exception as exc: + logger.warning( + "VideoCaptionDataset failed to fetch index %s on attempt %s/%s: %s", + idx, + try_count + 1, + max_try_count, + exc, + ) + idx = random.randint(0, self.dataset_length - 1) + continue + + if data is not None: + return { + "video": data["video"], + "caption": data["caption"], + } + + idx = random.randint(0, self.dataset_length - 1) + + logger.warning( + "VideoCaptionDataset exhausted %s attempts without a valid sample; returning None.", + max_try_count, + ) + return None + + + def _sample_data_webvid10m(self): + store_path = osp.join(self.webvid10m_path, "video_store") + + row = self.video_caption_pairs['webvid10m'].sample(1).iloc[0] + video_id = str(row["videoid"]) + url = row["contentUrl"] + caption = row["name"] + + video_path = osp.join(store_path, f"{video_id}.mp4") + if not osp.exists(video_path): # not downloaded yet + download_video_url(url, video_path) + + # print(video_id) + # print(_whoami_str()) + + return video_path, caption + + + + def _sample_data(self, idx): + if self.dataset_name == 'webvid10m': + # currently randomly sample from the dataset + video_path, caption = self._sample_data_webvid10m() + elif self.dataset_name == 'panda70m': + raise NotImplementedError("Panda70m is not implemented yet.") + # video_path, caption = self._sample_data_panda70m() + elif self.dataset_name == 'openvid1m': + data_row = self.vid_data.iloc[idx] + video_path = osp.join(self.dataset_root, "video", data_row["video"]) + caption = data_row["caption"] + elif self.dataset_name == 'llavavid': + data_row = self.vid_data[idx] + video_entry = data_row['video'] + cfg_name = data_row.get('llavavid_config') if isinstance(data_row, dict) else None + caption = data_row['conversations'] # this is a list of turns in llavavid + + resolved_video_path = None + if isinstance(video_entry, str): + candidate_paths = [] + video_path_obj = Path(video_entry) + if video_path_obj.is_absolute() and video_path_obj.exists(): + resolved_video_path = video_path_obj + else: + if hasattr(self, "llavavid_video_root"): + base_root = Path(self.llavavid_video_root) + if cfg_name: + candidate_paths.append(base_root / cfg_name / video_entry) + candidate_paths.append(base_root / video_entry) + # Also allow treating the stored value as relative to current dir. + candidate_paths.append(Path(video_entry)) + + for candidate in candidate_paths: + if candidate.exists(): + resolved_video_path = candidate + break + + if resolved_video_path is None: + logger.warning( + "LLaVA-Video sample missing video file: %s (config=%s)", + video_entry, + cfg_name, + ) + return None + + if resolved_video_path.suffix.lower() == ".mkv": + logger.warning( + "LLaVA-Video skipping MKV file: %s (config=%s)", + resolved_video_path, + cfg_name, + ) + return None + + video_path = str(resolved_video_path) + else: + raise ValueError(f"Invalid dataset name: {self.dataset_name}. Available datasets: panda70m, webvid10m, openvid1m") + + try: + frames = load_video_mp4( + video_path=video_path, + sample_method=self.sample_method, + num_frames=self.num_frames, + resolution=self.resolution, + transform=self.transform, + strict=False, + ) + except Exception as exc: + logger.warning( + "LLaVA-Video sample failed to load (%s): %s", + video_path, + exc, + ) + return None + if frames is None: + logger.warning( + "LLaVA-Video sample timed out while reading frames (%s); skipping sample.", + video_path, + ) + return None + + return { + "video": frames, # torch tensor (T, C, H, W) + "caption": caption # input_ids (seq_len); str + } + +def download_video_url(url: str, save_path, timeout=10, max_retries=3) -> bool: + for attempt in range(1, max_retries + 1): + try: + with requests.get(url, stream=True, timeout=timeout) as r: + r.raise_for_status() + with open(save_path, 'wb') as f: + for chunk in r.iter_content(chunk_size=8192): + if chunk: + f.write(chunk) + return True # Success + + except Exception as e: + print(f"[Attempt {attempt}/{max_retries}] Download failed: {e}") + if attempt < max_retries: + sleep_time = 2 ** (attempt - 1) # exponential backoff: 1,2,4,8,... + time.sleep(sleep_time) + else: + return False # all attempts failed + + return False + +def load_video_mp4( + video_path, + sample_method: str = 'uniform', + num_frames: int = 8, + resolution: int = 256, + transform=None, + *, + per_frame_timeout: float = 1.5, + read_retry_interval: float = 0.05, + strict: bool = True, + ): + """ + Load video frames and return them as a list of PIL images. + + Args: + video_path: Path to the video file. + sample_method: Sampling method, 'uniform' or 'random'. + num_frames: Number of frames to sample from the video. + per_frame_timeout: Max seconds to block while seeking/reading a frame. + read_retry_interval: Delay between read retries while waiting for a frame. + strict: When False, return None on timeout/seek failure instead of raising. + + Returns: + List[Image.Image] | None (if strict=False and a timeout/seek failure occurs) + """ + + with open(os.devnull, "w") as devnull, contextlib.redirect_stderr(devnull): + cap = cv2.VideoCapture(video_path) + if not cap.isOpened(): + raise IOError(f"Could not open video file {video_path}") + + if per_frame_timeout <= 0: + per_frame_timeout = 0.1 + if read_retry_interval <= 0: + read_retry_interval = 0.01 + + def _read_frame_with_timeout(frame_index: Optional[int] = None): + deadline = time.monotonic() + per_frame_timeout + attempts = 0 + while True: + if frame_index is not None: + cap.set(cv2.CAP_PROP_POS_FRAMES, int(frame_index)) + ret, frame = cap.read() + if ret and frame is not None: + return frame + attempts += 1 + if time.monotonic() >= deadline: + return None + time.sleep(min(read_retry_interval, max(deadline - time.monotonic(), 0.0))) + + try: + frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + except Exception: + frame_count = -1 + + if frame_count is None or frame_count <= 0: + # Fallback: attempt to read sequentially but stop early on failure + frames = [] + try: + while len(frames) < num_frames: + frame = _read_frame_with_timeout() + if frame is None: + break + frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))) + finally: + cap.release() + + if len(frames) < num_frames: + msg = f"Video {video_path} has insufficient frames ({len(frames)})." + if strict: + raise ValueError(msg) + logger.warning("%s Skipping sample.", msg) + return None + selected = frames + else: + if frame_count < num_frames: + cap.release() + msg = f"Video {video_path} has insufficient frames ({frame_count})." + if strict: + raise ValueError(msg) + logger.warning("%s Skipping sample.", msg) + return None + + if sample_method == 'uniform': + indices = np.linspace(0, frame_count - 1, num_frames).astype(int) + elif sample_method == 'random': + indices = np.sort(np.random.choice(frame_count, num_frames, replace=False)) + else: + cap.release() + raise ValueError(f"Sampling method {sample_method} not supported.") + + selected = [] + try: + for idx in indices: + frame = _read_frame_with_timeout(idx) + if frame is None: + msg = ( + f"Timed out ({per_frame_timeout:.2f}s) seeking frame {idx} in {video_path}" + ) + if strict: + raise TimeoutError(msg) + logger.warning("%s. Skipping sample.", msg) + return None + selected.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))) + finally: + cap.release() + + sampled_frames = [] + for frame in selected: + if transform: + frame = transform(frame, resolution=resolution) + sampled_frames.append(frame) + + return sampled_frames + + +class VideoSpeechDataset(Dataset): + """Loads paired video clips and speech audio paths or pre-tokenized speech.""" + + def __init__( + self, + *, + transform=None, + resolution: int = 256, + num_frames: int = 8, + video_root: str = "/home/work/AIDAS/data/video/openvid1m/video/video", + audio_root: str = "/home/work/AIDAS/data/video-speech", + speech_dir_name: str = "openvid-speech-trunc", + index_path: str = "/home/work/AIDAS/data/video-speech/openvid-speech.csv", + sample_method: str = "uniform", + precomputed_tokens_root: Optional[str] = None, + ) -> None: + self.transform = transform + self.resolution = resolution + self.num_frames = num_frames + self.sample_method = sample_method or "uniform" + if self.sample_method not in {"uniform", "random"}: + logger.warning("Unknown sample_method '%s', defaulting to 'uniform'", self.sample_method) + self.sample_method = "uniform" + + self.video_root = Path(video_root).expanduser().resolve() + audio_base = Path(audio_root).expanduser() + if speech_dir_name: + audio_base = audio_base / speech_dir_name + self.audio_root = audio_base.resolve() + + self.index_path = Path(index_path).expanduser().resolve() + if not self.index_path.exists(): + raise FileNotFoundError(f"VideoSpeechDataset index not found: {self.index_path}") + + self.precomputed_tokens_root = ( + Path(precomputed_tokens_root).expanduser().resolve() + if precomputed_tokens_root + else None + ) + if self.precomputed_tokens_root is not None and not self.precomputed_tokens_root.exists(): + logger.warning( + "Precomputed speech token root %s missing; falling back to raw audio paths.", + self.precomputed_tokens_root, + ) + self.precomputed_tokens_root = None + + self._samples: list[tuple[Path, Path]] = [] + self._token_cache: Dict[str, torch.Tensor] = {} + self._token_cache_limit = 4096 + + self._load_index() + if not self._samples: + raise RuntimeError(f"VideoSpeechDataset found no valid samples in {self.index_path}") + + def _load_index(self) -> None: + missing = 0 + with self.index_path.open("r", newline="") as csvfile: + reader = csv.reader(csvfile) + for row in reader: + if not row: + continue + base = row[0].strip() + if not base: + continue + if base.lower().endswith(".wav"): + base = base[:-4] + video_path = self.video_root / f"{base}.mp4" + audio_path = self.audio_root / f"{base}.wav" + if not video_path.is_file() or not audio_path.is_file(): + missing += 1 + continue + self._samples.append((video_path, audio_path)) + if missing: + logger.info( + "VideoSpeechDataset skipped %d entries missing media (index=%s)", + missing, + self.index_path, + ) + + def __len__(self) -> int: + return len(self._samples) + + def _transform_frame(self, image: Image.Image, resolution: int) -> torch.Tensor: + if self.transform is None: + return utils_image_transform(image, resolution) + try: + return self.transform(image, resolution=resolution) + except TypeError: + return self.transform(image) + + def _resolve_token_path(self, audio_path: Path) -> Optional[Path]: + if self.precomputed_tokens_root is None: + return None + digest = hashlib.sha1(os.path.abspath(str(audio_path)).encode("utf-8")).hexdigest() + return self.precomputed_tokens_root / digest[:2] / digest[2:4] / f"{digest}.pt" + + def _get_precomputed_tokens(self, audio_path: Path) -> Optional[torch.Tensor]: + cache_key = os.path.abspath(str(audio_path)) + cached = self._token_cache.get(cache_key) + if cached is not None: + return cached.clone() + + token_path = self._resolve_token_path(audio_path) + if token_path is None or not token_path.exists(): + return None + try: + tokens = torch.load(token_path, map_location="cpu") + except Exception as exc: + logger.warning("Failed to load precomputed speech tokens %s: %s", token_path, exc) + return None + if not isinstance(tokens, torch.Tensor): + return None + tokens = tokens.to(dtype=torch.long, copy=False) + if len(self._token_cache) < self._token_cache_limit: + self._token_cache[cache_key] = tokens + return tokens.clone() + + def _prepare_speech_entry(self, audio_path: Path): + tokens = self._get_precomputed_tokens(audio_path) + if tokens is not None: + return tokens + return str(audio_path) + + def __getitem__(self, idx: int) -> Dict[str, Any]: + video_path, audio_path = self._samples[idx] + frames = load_video_mp4( + str(video_path), + sample_method=self.sample_method, + num_frames=self.num_frames, + resolution=self.resolution, + transform=self._transform_frame, + ) + speech_entry = self._prepare_speech_entry(audio_path) + return { + "video": frames, + "speech": speech_entry, + } + +class TextImageInterleavedDataset: + """ + HF-backed dataset that yields rows of: + { + "image_paths": [str, ...], # absolute paths (no decoding) + "user_text": str, + "assistant_text": str, + } + """ + + def __init__( + self, + *, + configs: Union[str, Sequence[str], None] = None, # default: all configs + split: str = "train", + data_root: str = "/home/work/AIDAS/data/TIGER-Lab/Mantis-Instruct", + max_images: Optional[int] = None, + filter_empty: bool = True, + resolution: int = 256, + # sampling controls + per_config_fraction: float = 1/7, # ← sample 1/7 PER CONFIG + sample_seed: int = 42, + # kept for compatibility, not used in this 1/7-per-config version + max_samples: Optional[int] = 1_000_000, + local_data_root: Optional[str] = None, + local_data_files: Optional[Dict[str, Any]] = None, + local_files_only: bool = False, + ): + self.data_root = data_root + self.split = self._normalize_split(split) + self.max_images = max_images + self.filter_empty = filter_empty + self.resolution = resolution + self.local_data_root = local_data_root + self.local_data_files = local_data_files or {} + self._download_config = DownloadConfig(local_files_only=True) if local_files_only else None + + # cache transforms + self._tfm_crop = transforms.Compose([ + transforms.Resize(resolution, interpolation=transforms.InterpolationMode.BICUBIC), + transforms.CenterCrop((resolution, resolution)), + transforms.ToTensor(), + transforms.Normalize(mean=[0.5,0.5,0.5], std=[0.5,0.5,0.5]), + ]) + self._tfm_squash = transforms.Compose([ + transforms.Resize((resolution, resolution), interpolation=transforms.InterpolationMode.BICUBIC), + transforms.ToTensor(), + transforms.Normalize(mean=[0.5,0.5,0.5], std=[0.5,0.5,0.5]), + ]) + + # ---- resolve configs ---- + if configs is None or configs == "all": + cfgs = self._resolve_configs_from_local() + if not cfgs: + cfgs = sorted(get_dataset_config_names("TIGER-Lab/Mantis-Instruct")) + elif isinstance(configs, str): + cfgs = [configs] + else: + cfgs = list(configs) + self.configs = cfgs + + rng = np.random.default_rng(sample_seed) + per_cfg_ds: List[HFDataset] = [] + + acc = Accelerator() + + for cfg in cfgs: + base_ds = self._load_base_dataset(cfg, acc) + if base_ds is None: + continue + # --- SAMPLE 1/7 OF BASE ROWS *PER CONFIG* (before any map/expansion) --- + n = len(base_ds) + if n == 0: + continue + k = max(1, int(np.floor(n * per_config_fraction))) + # reproducible uniform sample without replacement + sel_idx = rng.choice(n, size=k, replace=False) + base_ds = base_ds.select(list(sel_idx)) + + # locate image dir for this (cfg, split) + img_dir = self._resolve_img_dir(cfg, self.split) + if img_dir is None: + raise FileNotFoundError(f"No image dir for config='{cfg}', split='{self.split}'") + + # (1) attach constants + def add_const_cols(batch): + m = len(next(iter(batch.values()))) if batch else 0 + return {"config": [cfg]*m, "img_dir": [img_dir]*m} + + ds = base_ds.map(add_const_cols, batched=True) + + # (2) normalize image column → absolute string paths + image_key = self._guess_image_key(ds.column_names) + + def make_abs_paths(batch): + bases = batch["img_dir"] # list[str] per row + rels = batch[image_key] # per-row: list[dict]|dict|list[str]|str|None + + def dict_to_rel(d: Dict[str, Any]) -> Optional[str]: + # typical HF Image: {"path": "...", "bytes": ...} + for k in ("path", "file_name", "filepath", "image_path", "name"): + v = d.get(k) + if isinstance(v, str) and v: + return v + # nested + img = d.get("image") + if isinstance(img, dict): + v = img.get("path") + if isinstance(v, str) and v: + return v + return None + + out_paths = [] + for base, r in zip(bases, rels): + # normalize r → list[str] + if r is None: + row = [] + elif isinstance(r, str): + row = [r] + elif isinstance(r, dict): + s = dict_to_rel(r) + row = [s] if s else [] + elif isinstance(r, list): + tmp = [] + for x in r: + if isinstance(x, str): + tmp.append(x) + elif isinstance(x, dict): + s = dict_to_rel(x) + if s: + tmp.append(s) + row = tmp + else: + row = [] + + # join to absolute (keep absolute if already) + abs_paths = [p if os.path.isabs(p) else os.path.join(base, p) for p in row if isinstance(p, str)] + + # cap if requested + if self.max_images is not None and len(abs_paths) > self.max_images: + abs_paths = abs_paths[: self.max_images] + + out_paths.append(abs_paths) + + return {"image_paths": out_paths} + + ds = ds.map(make_abs_paths, batched=True) + + # (3) expand conversation: one row per (user → assistant) turn + conv_key = "conversation" + + def expand_turns(batch): + image_paths_list = batch["image_paths"] + conversations = batch.get(conv_key, [[]] * len(image_paths_list)) + + out_img_paths, out_user, out_assistant = [], [], [] + + for img_paths, conv in zip(image_paths_list, conversations): + conv = conv or [] + # walk adjacent pairs + i = 0 + while i < len(conv) - 1: + a, b = conv[i], conv[i + 1] + if (isinstance(a, dict) and isinstance(b, dict) + and a.get("role") == "user" and b.get("role") == "assistant"): + user_text = (a.get("content") or "").strip() + assistant_text = (b.get("content") or "").strip() + if (not self.filter_empty) or assistant_text: + out_img_paths.append(img_paths) + out_user.append(user_text) + out_assistant.append(assistant_text) + i += 2 + else: + i += 1 + + return { + "image_paths": out_img_paths, + "user_text": out_user, + "assistant_text": out_assistant, + } + + ds = ds.map(expand_turns, batched=True, remove_columns=ds.column_names) + + if self.filter_empty: + ds = ds.filter(lambda e: bool(e["assistant_text"])) + + per_cfg_ds.append(ds) + + if not per_cfg_ds: + raise ValueError("Empty dataset after per-config sampling and preprocessing.") + + self.dataset = concatenate_datasets(per_cfg_ds) if len(per_cfg_ds) > 1 else per_cfg_ds[0] + self.dataset = self.dataset.with_format("python") + print(f"[HF Dataset] per-config 1/7 sampled; configs={self.configs}, split='{self.split}', rows={len(self.dataset)}") + + # ---- public API ---- + def __len__(self): + return len(self.dataset) + + def __getitem__(self, idx): + start_idx = idx + attempts = 0 + max_attempts = 10 + + while attempts < max_attempts: + ex = self.dataset[idx] + + text = ( + "<|start_header_id|>user<|end_header_id|>\n" + f"{ex['user_text']}\n" + "<|start_header_id|>assistant<|end_header_id|>\n" + f"{ex['assistant_text']}" + ) + + paths = ex["image_paths"] + imgs: list[torch.Tensor] = [] + for path in paths: + img = self._load_and_transform_one(path) + if img is not None: + imgs.append(img) + + if imgs: + return { + "images": imgs, + "text": text, + } + + attempts += 1 + idx = (idx + 1) % len(self.dataset) + if idx == start_idx: + break + + raise RuntimeError("TextImageInterleavedDataset: no valid images found after retries.") + + # ---- helpers ---- + @staticmethod + def _normalize_split(split: str) -> str: + s = split.lower() + return {"val": "validation", "dev": "validation"}.get(s, s) + + def _resolve_img_dir(self, cfg: str, split: str) -> Optional[str]: + # Typical local layout: + # {data_root}/{cfg}/{split}_images + # {data_root}/{cfg}/images + cand1 = os.path.join(self.data_root, cfg, f"{split}_images") + cand2 = os.path.join(self.data_root, cfg, "images") + for c in (cand1, cand2): + if os.path.isdir(c): + return c + return None + + def _load_and_transform_one(self, path: str): + try: + with Image.open(path) as im: + im = im.convert("RGB") + except FileNotFoundError: + return None + except Exception: + return None + + return self._tfm_crop(im) + + @staticmethod + def _guess_image_key(cols: List[str]) -> str: + for k in ("images", "image_paths", "imgs", "paths", "image"): + if k in cols: + return k + raise KeyError(f"Cannot find image column among {cols}") + + def _resolve_configs_from_local(self) -> List[str]: + cfgs: List[str] = [] + + if self.local_data_root: + root = Path(self.local_data_root) + if root.is_dir(): + for entry in sorted(root.iterdir()): + if not entry.is_dir(): + continue + if self._has_split_data(entry): + cfgs.append(entry.name) + + if not cfgs and self.local_data_files: + cfgs = [k for k in sorted(self.local_data_files.keys()) if k != "default"] + + return cfgs + + def _has_split_data(self, cfg_path: Path) -> bool: + split_dir = cfg_path / self.split + if split_dir.is_dir(): + return True + + alt_dirs = [ + cfg_path / f"{self.split}.dataset", + cfg_path / f"{self.split}.arrow", + ] + for candidate in alt_dirs: + if candidate.is_dir(): + return True + + patterns = [ + cfg_path / self.split / "*.arrow", + cfg_path / self.split / "*.parquet", + cfg_path / f"{self.split}/*.arrow", + cfg_path / f"{self.split}/*.parquet", + cfg_path / f"{self.split}*.arrow", + cfg_path / f"{self.split}*.parquet", + ] + for pattern in patterns: + if glob(str(pattern)): + return True + + return False + + def _load_base_dataset(self, cfg: str, acc: Accelerator) -> Optional[HFDataset]: + base_ds: Optional[HFDataset] = None + + if self.local_data_root is not None: + # print(self.local_data_root) + base_ds = self._load_from_local_root(cfg) + + if base_ds is None and self.local_data_files: + base_ds = self._load_from_local_data_files(cfg) + + if base_ds is not None: + return base_ds + + kwargs = {} + if self._download_config is not None: + kwargs["download_config"] = self._download_config + + if acc.num_processes > 1: + acc.wait_for_everyone() + try: + base_ds = load_dataset( + "TIGER-Lab/Mantis-Instruct", + cfg, + split=self.split, + **kwargs, + ) + except Exception as exc: + if self._download_config is not None: + raise RuntimeError( + f"Failed to load local dataset for config='{cfg}'. " + "Ensure that the dataset is cached or provide 'local_data_root'." + ) from exc + raise + finally: + if acc.num_processes > 1: + acc.wait_for_everyone() + + return base_ds + + def _load_from_local_root(self, cfg: str) -> Optional[HFDataset]: + cfg_root = os.path.join(self.local_data_root, cfg) + if not os.path.exists(cfg_root): + return None + + candidates = [ + cfg_root, + os.path.join(cfg_root, self.split), + os.path.join(cfg_root, f"{self.split}.dataset"), + ] + + for path in candidates: + if not os.path.isdir(path): + continue + try: + loaded = load_from_disk(path) + if isinstance(loaded, DatasetDict): + if self.split in loaded: + return loaded[self.split] + continue + return loaded + except Exception: + continue + + patterns = [ + os.path.join(cfg_root, f"{self.split}.parquet"), + os.path.join(cfg_root, f"{self.split}/*.parquet"), + os.path.join(cfg_root, f"{self.split}_*.parquet"), + os.path.join(cfg_root, f"{self.split}.json"), + os.path.join(cfg_root, f"{self.split}.jsonl"), + os.path.join(cfg_root, f"{self.split}/*.jsonl"), + os.path.join(cfg_root, f"{self.split}.arrow"), + os.path.join(cfg_root, f"{self.split}/*.arrow"), + ] + + for pattern in patterns: + files = sorted(glob(pattern)) + if files: + return self._load_from_files(files) + + return None + + def _load_from_local_data_files(self, cfg: str) -> Optional[HFDataset]: + spec = self.local_data_files.get(cfg) or self.local_data_files.get("default") + if spec is None: + return None + + if isinstance(spec, str): + entries = [spec] + loader = None + elif isinstance(spec, dict): + loader = spec.get("type") or spec.get("loader") or spec.get("format") + files = spec.get(self.split) or spec.get("files") + if files is None: + return None + entries = files if isinstance(files, list) else [files] + else: + entries = list(spec) + loader = None + + resolved_files: list[str] = [] + for entry in entries: + if not entry: + continue + matched = sorted(glob(entry)) + if matched: + resolved_files.extend(matched) + elif os.path.exists(entry): + resolved_files.append(entry) + + if not resolved_files: + return None + + return self._load_from_files(resolved_files, loader_hint=loader) + + def _load_from_files(self, files: list[str], loader_hint: Optional[str] = None) -> Optional[HFDataset]: + if not files: + return None + + ext = Path(files[0]).suffix.lower() + loader = loader_hint + if loader is None: + if ext in (".parquet",): + loader = "parquet" + elif ext in (".json", ".jsonl"): + loader = "json" + elif ext in (".arrow", ".feather"): + loader = "arrow" + + if loader == "parquet": + return load_dataset("parquet", data_files={self.split: files}, split=self.split) + if loader in {"json", "jsonl"}: + return load_dataset("json", data_files={self.split: files}, split=self.split) + if loader == "arrow": + datasets = [HFDataset.from_file(path) for path in files] + return concatenate_datasets(datasets) if len(datasets) > 1 else datasets[0] + + return None + + +class HFInstructionTextDataset(Dataset): + """Mixed instruction-following text dataset sourced from multiple HF corpora.""" + + HF_SOURCES = ( + { + "name": "openai/gsm8k", + "config": "main", + "split": "train", + "user_key": "question", + "assistant_key": "answer", + }, + { + "name": "qwedsacf/grade-school-math-instructions", + "config": None, + "split": "train", + "user_key": "INSTRUCTION", + "assistant_key": "RESPONSE", + }, + { + "name": "alespalla/chatbot_instruction_prompts", + "config": None, + "split": "train", + "user_key": "prompt", + "assistant_key": "response", + }, + { + "name": "TIGER-Lab/MathInstruct", + "config": None, + "split": "train", + "user_key": "instruction", + "assistant_key": "output", + }, + ) + + def __init__( + self, + *, + split: str = "train", + max_samples_per_source: Optional[int] = None, + max_total_samples: Optional[int] = None, + seed: int = 42, + ) -> None: + self.split = split + self.seed = seed + self.samples: List[str] = [] + + rng = random.Random(seed) + + for source in self.HF_SOURCES: + desired_split = source.get("split", split) + try: + dataset_name = source["name"] + dataset_config = source.get("config") + if dataset_config is not None: + hf_ds = load_dataset(dataset_name, dataset_config, split=desired_split) + else: + hf_ds = load_dataset(dataset_name, split=desired_split) + except Exception as exc: + print(f"[HFInstructionTextDataset] Failed to load {source['name']}: {exc}") + continue + + if max_samples_per_source is not None and len(hf_ds) > max_samples_per_source: + hf_ds = hf_ds.shuffle(seed=seed).select(range(max_samples_per_source)) + + user_key = source["user_key"] + assistant_key = source["assistant_key"] + + for example in hf_ds: + user_raw = str(example.get(user_key, "")).strip() + assistant_raw = str(example.get(assistant_key, "")).strip() + if not user_raw or not assistant_raw: + continue + + formatted = self._format_dialogue(user_raw, assistant_raw) + if formatted: + self.samples.append(formatted) + + if not self.samples: + raise ValueError("HFInstructionTextDataset loaded zero valid samples.") + + rng.shuffle(self.samples) + + if max_total_samples is not None: + self.samples = self.samples[: max_total_samples] + + @staticmethod + def _format_dialogue(user_text: str, assistant_text: str) -> str: + return ( + "<|start_header_id|>user<|end_header_id|>\n" + f"{user_text}\n" + "<|eot_id><|start_header_id|>assistant<|end_header_id|>\n" + f"{assistant_text}" + ) + + def __len__(self) -> int: + return len(self.samples) + + def __getitem__(self, index: int) -> Dict[str, str]: + return {"input_ids": self.samples[index]} + + @staticmethod + def collate_fn(batch: List[Dict[str, str]]) -> Dict[str, List[str]]: + return {"input_ids": [example["input_ids"] for example in batch]} + + +class TextToImage2MDataset(Dataset): + """Loads jackyhate/text-to-image-2M for text-to-image fine-tuning.""" + + def __init__( + self, + split: str = "train", + resolution: int = 256, + dataset_name: str = "jackyhate/text-to-image-2M", + cache_dir: str | None = None, + local_files_only: bool = False, + ) -> None: + self.resolution = resolution + self.dataset_name = dataset_name + self.cache_dir = cache_dir + self.local_files_only = local_files_only + + download_cfg = None + if local_files_only: + download_cfg = DownloadConfig(local_files_only=True) + + self._dataset = load_dataset( + dataset_name, + split=split, + cache_dir=cache_dir, + download_config=download_cfg, + ) + + def __len__(self) -> int: + return len(self._dataset) + + def __getitem__(self, idx: int) -> Dict[str, Any]: + sample = self._dataset[idx] + + prompt = None + json_meta = sample.get("json") + if isinstance(json_meta, dict): + prompt = json_meta.get("prompt") + if prompt is None: + prompt = sample.get("prompt", "") + + image_field = sample.get("jpg") or sample.get("image") + if image_field is None: + raise KeyError("Expected image field 'jpg' in text-to-image-2M sample") + + if isinstance(image_field, Image.Image): + image = image_field.convert("RGB") + elif isinstance(image_field, bytes): + image = Image.open(BytesIO(image_field)).convert("RGB") + else: + image = Image.fromarray(np.array(image_field)).convert("RGB") + + image_tensor = utils_image_transform(image, self.resolution) + + return { + "input_prompt": prompt, + "output_prompt": None, + "edit_prompt": None, + "inverse_prompt": None, + "input_image": image_tensor, + "output_image": image_tensor, + } + +class HQEditX2IDataset(Dataset): + + def __init__( + self, + split: str = "train", + resolution: int = 256, + dataset_name: str = "UCSC-VLAA/HQ-Edit", + cache_dir: str = "/home/work/AIDAS/huggingface/datasets", + ): + self.resolution = resolution + self.cache_dir = cache_dir # retained for API compatibility + + self._dataset = load_dataset(dataset_name, split=split) + + def __len__(self) -> int: + return len(self._dataset) + + def __getitem__(self, idx: int) -> Dict[str, Any]: + sample = self._dataset[idx] + + input_tensor = utils_image_transform(sample['input_image'].convert("RGB"), self.resolution) + output_tensor = utils_image_transform(sample['output_image'].convert("RGB"), self.resolution) + + return { + "input_prompt": sample["input"], + "output_prompt": sample["output"], + "edit_prompt": sample["edit"], + "inverse_prompt": sample["inverse_edit"], + "input_image": input_tensor, + "output_image": output_tensor, + } + + +class CombinedX2IDataset(Dataset): + """Round-robin combination of multiple x2i-style datasets.""" + + def __init__(self, datasets: Sequence[Dataset]): + if not datasets: + raise ValueError("CombinedX2IDataset requires at least one dataset.") + self.datasets = list(datasets) + self.lengths = [len(ds) for ds in self.datasets] + if any(length == 0 for length in self.lengths): + raise ValueError("Underlying x2i dataset has zero length.") + self.cumulative = list(itertools.accumulate(self.lengths)) + self.total_length = self.cumulative[-1] + + def __len__(self) -> int: + return self.total_length + + def __getitem__(self, idx: int) -> Dict[str, Any]: + if idx < 0 or idx >= self.total_length: + raise IndexError(f"Index {idx} out of bounds for CombinedX2IDataset of length {self.total_length}") + + dataset_idx = bisect.bisect_right(self.cumulative, idx) + prev = self.cumulative[dataset_idx - 1] if dataset_idx > 0 else 0 + local_idx = idx - prev + return self.datasets[dataset_idx][local_idx] + + +class OpenImageI2IDataset(Dataset): + """ + Image-to-image dataset built from local Open Images edit JSONL files. + + Supports three JSONL schemas: + * SFT-style single turn edits (text + output_image + local_input_image) + * Preference data; only positive edits (output_image) are used by default + * Multi-turn edits which are flattened into single-turn pairs + """ + + def __init__( + self, + resolution: int = 256, + image_root: str | None = None, + sft_jsonl: Union[str, Sequence[str], None] = None, + pref_jsonl: Union[str, Sequence[str], None] = None, + multi_turn_jsonl: Union[str, Sequence[str], None] = None, + prefer_summarized_text: bool = True, + pref_positive_only: bool = True, + skip_missing: bool = True, + max_samples_per_source: int | None = None, + max_total_samples: int | None = None, + seed: int | None = None, + ) -> None: + self.resolution = resolution + self.image_root = image_root + self.prefer_summarized_text = prefer_summarized_text + self.pref_positive_only = pref_positive_only + self.skip_missing = skip_missing + self._rng = random.Random(seed if seed is not None else 0) + self._per_source_limit = self._coerce_positive_int(max_samples_per_source) + self._total_limit = self._coerce_positive_int(max_total_samples) + + self._samples: list[dict[str, str]] = [] + self._stats: dict[str, int] = { + "sft": 0, + "pref": 0, + "multi_turn": 0, + "missing_paths": 0, + "invalid_records": 0, + } + + sft_paths = self._coerce_paths(sft_jsonl) + pref_paths = self._coerce_paths(pref_jsonl) + multi_turn_paths = self._coerce_paths(multi_turn_jsonl) + + for path in sft_paths: + self._samples.extend(self._load_single_turn_file(path, source_key="sft")) + + for path in pref_paths: + if not self.pref_positive_only: + logger.warning("OpenImageI2IDataset currently only supports positive preference pairs.") + self._samples.extend(self._load_single_turn_file(path, source_key="pref")) + + for path in multi_turn_paths: + self._samples.extend(self._load_multi_turn_file(path)) + + if self._total_limit is not None and len(self._samples) > self._total_limit: + self._rng.shuffle(self._samples) + self._samples = self._samples[: self._total_limit] + + if not self._samples: + raise ValueError("OpenImageI2IDataset could not load any valid examples.") + + logger.info( + "Loaded %d OpenImage i2i samples (sft=%d, pref=%d, multi_turn=%d, missing_paths=%d, invalid=%d).", + len(self._samples), + self._stats["sft"], + self._stats["pref"], + self._stats["multi_turn"], + self._stats["missing_paths"], + self._stats["invalid_records"], + ) + + def __len__(self) -> int: + return len(self._samples) + + def __getitem__(self, idx: int) -> Dict[str, Any]: + record = self._samples[idx] + input_image = Image.open(record["input_path"]).convert("RGB") + target_image = Image.open(record["target_path"]).convert("RGB") + + input_tensor = utils_image_transform(input_image, self.resolution) + target_tensor = utils_image_transform(target_image, self.resolution) + + prompt = record["prompt"] + + return { + "input_prompt": prompt, + "output_prompt": None, + "edit_prompt": prompt, + "inverse_prompt": None, + "input_image": input_tensor, + "output_image": target_tensor, + } + + def _load_single_turn_file(self, path: str, *, source_key: str) -> list[dict[str, str]]: + file_path = os.path.abspath(os.path.expanduser(path)) + if not os.path.exists(file_path): + logger.warning("OpenImageI2IDataset: JSONL file not found: %s", file_path) + return [] + + base_dir = os.path.dirname(file_path) + samples: list[dict[str, str]] = [] + with open(file_path, "r", encoding="utf-8") as handle: + for line in handle: + line = line.strip() + if not line: + continue + try: + record = json.loads(line) + except json.JSONDecodeError: + self._stats["invalid_records"] += 1 + continue + + prompt = self._select_prompt(record) + input_path = self._resolve_path(record.get("local_input_image"), base_dir=base_dir) + output_path = self._resolve_path(record.get("output_image"), base_dir=base_dir) + sample = self._build_sample(prompt, input_path, output_path) + if sample: + samples.append(sample) + + if self._per_source_limit is not None and len(samples) > self._per_source_limit: + self._rng.shuffle(samples) + samples = samples[: self._per_source_limit] + + self._stats[source_key] += len(samples) + return samples + + def _load_multi_turn_file(self, path: str) -> list[dict[str, str]]: + file_path = os.path.abspath(os.path.expanduser(path)) + if not os.path.exists(file_path): + logger.warning("OpenImageI2IDataset: JSONL file not found: %s", file_path) + return [] + + base_dir = os.path.dirname(file_path) + samples: list[dict[str, str]] = [] + with open(file_path, "r", encoding="utf-8") as handle: + for line in handle: + line = line.strip() + if not line: + continue + try: + record = json.loads(line) + except json.JSONDecodeError: + self._stats["invalid_records"] += 1 + continue + + multi_samples = self._expand_multi_turn(record, base_dir=base_dir) + if multi_samples: + samples.extend(multi_samples) + + if self._per_source_limit is not None and len(samples) > self._per_source_limit: + self._rng.shuffle(samples) + samples = samples[: self._per_source_limit] + + self._stats["multi_turn"] += len(samples) + return samples + + def _expand_multi_turn(self, record: dict, *, base_dir: str) -> list[dict[str, str]]: + prompts = record.get("metadata_edit_turn_prompts") or [] + files = record.get("files") or [] + if not prompts or not files: + self._stats["invalid_records"] += 1 + return [] + + outputs: dict[int, str] = {} + final_image: str | None = None + for entry in files: + file_id = entry.get("id") + url = entry.get("url") + if not file_id or not url: + continue + if file_id.startswith("edit_turn"): + try: + idx = int(file_id.replace("edit_turn", "").strip()) + except ValueError: + continue + outputs[idx] = url + elif file_id == "final_image": + final_image = url + + current_input = self._resolve_path(record.get("local_input_image"), base_dir=base_dir) + if not current_input: + return [] + + samples: list[dict[str, str]] = [] + for turn_idx, prompt in enumerate(prompts, start=1): + target_rel = outputs.get(turn_idx) + if target_rel is None: + if turn_idx == len(prompts): + target_rel = final_image + else: + break + target_path = self._resolve_path(target_rel, base_dir=base_dir) + if not target_path: + break + + sample = self._build_sample(prompt, current_input, target_path) + if not sample: + break + + samples.append(sample) + current_input = target_path + + return samples + + def _select_prompt(self, record: dict) -> str | None: + if self.prefer_summarized_text and record.get("summarized_text"): + return record.get("summarized_text") + if record.get("text"): + return record.get("text") + return record.get("metadata_edit_turn_prompt") + + def _build_sample(self, prompt: str | None, input_path: str | None, target_path: str | None) -> dict[str, str] | None: + if not prompt: + self._stats["invalid_records"] += 1 + return None + if not input_path or not target_path: + return None + return { + "prompt": str(prompt).strip(), + "input_path": input_path, + "target_path": target_path, + } + + def _resolve_path(self, path: str | None, *, base_dir: str | None = None) -> str | None: + if not path or path.startswith("http://") or path.startswith("https://"): + return None + + candidates: list[str] = [] + normalized = path.replace("\\", "/") + + if os.path.isabs(normalized): + candidates.append(os.path.normpath(normalized)) + else: + if self.image_root: + candidates.append(os.path.normpath(os.path.join(self.image_root, normalized))) + if base_dir: + candidates.append(os.path.normpath(os.path.join(base_dir, normalized))) + + for candidate in candidates: + if not self.skip_missing or os.path.exists(candidate): + return candidate + + self._stats["missing_paths"] += 1 + return None + + def _coerce_paths(self, value: Union[str, Sequence[str], None]) -> list[str]: + if value is None: + return [] + if isinstance(value, str): + values = [value] + else: + values = [item for item in value if item] + return [os.path.abspath(os.path.expanduser(path)) for path in values] + + @staticmethod + def _coerce_positive_int(value: Any) -> int | None: + if value is None: + return None + try: + int_value = int(value) + except (TypeError, ValueError): + return None + return int_value if int_value > 0 else None + + +# import os, socket +# from typing import Optional + +# def _dist_identity(): +# """Return a dict with rank info from env/torch if available.""" +# info = {} +# # Env fallbacks for different launchers (torchrun/SLURM/MPI) +# def _get(*keys) -> Optional[int]: +# for k in keys: +# v = os.environ.get(k) +# if v is not None: +# try: +# return int(v) +# except ValueError: +# return None +# return None + +# info["rank"] = _get("RANK", "SLURM_PROCID", "OMPI_COMM_WORLD_RANK") +# info["local_rank"] = _get("LOCAL_RANK", "SLURM_LOCALID", "MPI_LOCALRANKID") +# info["node_rank"] = _get("NODE_RANK", "SLURM_NODEID") +# info["world_size"] = _get("WORLD_SIZE", "SLURM_NTASKS", "OMPI_COMM_WORLD_SIZE") +# info["hostname"] = socket.gethostname() +# info["pid"] = os.getpid() + +# # Optional: torch.distributed status +# try: +# import torch.distributed as dist +# info["dist_initialized"] = dist.is_available() and dist.is_initialized() +# if info["dist_initialized"]: +# info["rank"] = dist.get_rank() +# info["world_size"] = dist.get_world_size() +# info["backend"] = dist.get_backend() +# except Exception: +# info["dist_initialized"] = False + +# # Optional: DataLoader worker ID +# try: +# from torch.utils.data import get_worker_info +# wi = get_worker_info() +# info["worker_id"] = wi.id if wi is not None else None +# except Exception: +# info["worker_id"] = None + +# return info + +# def _whoami_str(): +# i = _dist_identity() +# return ( +# f"[PROC] rank={i['rank']} local_rank={i['local_rank']} node_rank={i['node_rank']} " +# f"world={i['world_size']} worker={i['worker_id']} " +# f"host={i['hostname']} pid={i['pid']} " +# f"{'(backend='+i['backend']+')' if i.get('backend') else ''}" +# ) + + +if __name__ == '__main__': + pass diff --git a/MMaDA/training/imagenet_dataset.py b/MMaDA/training/imagenet_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..c0404650730ff787b47978af31a81e8a4ce2eabe --- /dev/null +++ b/MMaDA/training/imagenet_dataset.py @@ -0,0 +1,82 @@ +# coding=utf-8 +# Copyright 2024 NUS Show Lab. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import collections +from typing import Any, Callable, Optional + +import torch +from torchvision.datasets.folder import DatasetFolder, default_loader +from training.utils import image_transform + + +class ImageNetDataset(DatasetFolder): + def __init__( + self, + root: str, + loader: Callable[[str], Any] = default_loader, + is_valid_file: Optional[Callable[[str], bool]] = None, + image_size=256, + ): + IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp") + + self.transform = image_transform + self.image_size = image_size + + super().__init__( + root, + loader, + IMG_EXTENSIONS if is_valid_file is None else None, + transform=self.transform, + target_transform=None, + is_valid_file=is_valid_file, + ) + + with open('./training/imagenet_label_mapping', 'r') as f: + self.labels = {} + for l in f: + num, description = l.split(":") + self.labels[int(num)] = description.strip() + + print("ImageNet dataset loaded.") + + def __getitem__(self, idx): + + try: + path, target = self.samples[idx] + image = self.loader(path) + image = self.transform(image, resolution=self.image_size) + input_ids = "{}".format(self.labels[target]) + class_ids = torch.tensor(target) + + return {'images': image, 'input_ids': input_ids, 'class_ids': class_ids} + + except Exception as e: + print(e) + return self.__getitem__(idx+1) + + def collate_fn(self, batch): + batched = collections.defaultdict(list) + for data in batch: + for k, v in data.items(): + batched[k].append(v) + for k, v in batched.items(): + if k not in ('input_ids'): + batched[k] = torch.stack(v, dim=0) + + return batched + + +if __name__ == '__main__': + pass diff --git a/MMaDA/training/imagenet_label_mapping b/MMaDA/training/imagenet_label_mapping new file mode 100644 index 0000000000000000000000000000000000000000..1f3555d09437e65ed458c98e84c481465fd7b30a --- /dev/null +++ b/MMaDA/training/imagenet_label_mapping @@ -0,0 +1,1000 @@ +0: tench +1: goldfish +2: great white shark +3: tiger shark +4: hammerhead +5: electric ray +6: stingray +7: cock +8: hen +9: ostrich +10: brambling +11: goldfinch +12: house finch +13: junco +14: indigo bunting +15: robin +16: bulbul +17: jay +18: magpie +19: chickadee +20: water ouzel +21: kite +22: bald eagle +23: vulture +24: great grey owl +25: European fire salamander +26: common newt +27: eft +28: spotted salamander +29: axolotl +30: bullfrog +31: tree frog +32: tailed frog +33: loggerhead +34: leatherback turtle +35: mud turtle +36: terrapin +37: box turtle +38: banded gecko +39: common iguana +40: American chameleon +41: whiptail +42: agama +43: frilled lizard +44: alligator lizard +45: Gila monster +46: green lizard +47: African chameleon +48: Komodo dragon +49: African crocodile +50: American alligator +51: triceratops +52: thunder snake +53: ringneck snake +54: hognose snake +55: green snake +56: king snake +57: garter snake +58: water snake +59: vine snake +60: night snake +61: boa constrictor +62: rock python +63: Indian cobra +64: green mamba +65: sea snake +66: horned viper +67: diamondback +68: sidewinder +69: trilobite +70: harvestman +71: scorpion +72: black and gold garden spider +73: barn spider +74: garden spider +75: black widow +76: tarantula +77: wolf spider +78: tick +79: centipede +80: black grouse +81: ptarmigan +82: ruffed grouse +83: prairie chicken +84: peacock +85: quail +86: partridge +87: African grey +88: macaw +89: sulphur-crested cockatoo +90: lorikeet +91: coucal +92: bee eater +93: hornbill +94: hummingbird +95: jacamar +96: toucan +97: drake +98: red-breasted merganser +99: goose +100: black swan +101: tusker +102: echidna +103: platypus +104: wallaby +105: koala +106: wombat +107: jellyfish +108: sea anemone +109: brain coral +110: flatworm +111: nematode +112: conch +113: snail +114: slug +115: sea slug +116: chiton +117: chambered nautilus +118: Dungeness crab +119: rock crab +120: fiddler crab +121: king crab +122: American lobster +123: spiny lobster +124: crayfish +125: hermit crab +126: isopod +127: white stork +128: black stork +129: spoonbill +130: flamingo +131: little blue heron +132: American egret +133: bittern +134: crane +135: limpkin +136: European gallinule +137: American coot +138: bustard +139: ruddy turnstone +140: red-backed sandpiper +141: redshank +142: dowitcher +143: oystercatcher +144: pelican +145: king penguin +146: albatross +147: grey whale +148: killer whale +149: dugong +150: sea lion +151: Chihuahua +152: Japanese spaniel +153: Maltese dog +154: Pekinese +155: Shih-Tzu +156: Blenheim spaniel +157: papillon +158: toy terrier +159: Rhodesian ridgeback +160: Afghan hound +161: basset +162: beagle +163: bloodhound +164: bluetick +165: black-and-tan coonhound +166: Walker hound +167: English foxhound +168: redbone +169: borzoi +170: Irish wolfhound +171: Italian greyhound +172: whippet +173: Ibizan hound +174: Norwegian elkhound +175: otterhound +176: Saluki +177: Scottish deerhound +178: Weimaraner +179: Staffordshire bullterrier +180: American Staffordshire terrier +181: Bedlington terrier +182: Border terrier +183: Kerry blue terrier +184: Irish terrier +185: Norfolk terrier +186: Norwich terrier +187: Yorkshire terrier +188: wire-haired fox terrier +189: Lakeland terrier +190: Sealyham terrier +191: Airedale +192: cairn +193: Australian terrier +194: Dandie Dinmont +195: Boston bull +196: miniature schnauzer +197: giant schnauzer +198: standard schnauzer +199: Scotch terrier +200: Tibetan terrier +201: silky terrier +202: soft-coated wheaten terrier +203: West Highland white terrier +204: Lhasa +205: flat-coated retriever +206: curly-coated retriever +207: golden retriever +208: Labrador retriever +209: Chesapeake Bay retriever +210: German short-haired pointer +211: vizsla +212: English setter +213: Irish setter +214: Gordon setter +215: Brittany spaniel +216: clumber +217: English springer +218: Welsh springer spaniel +219: cocker spaniel +220: Sussex spaniel +221: Irish water spaniel +222: kuvasz +223: schipperke +224: groenendael +225: malinois +226: briard +227: kelpie +228: komondor +229: Old English sheepdog +230: Shetland sheepdog +231: collie +232: Border collie +233: Bouvier des Flandres +234: Rottweiler +235: German shepherd +236: Doberman +237: miniature pinscher +238: Greater Swiss Mountain dog +239: Bernese mountain dog +240: Appenzeller +241: EntleBucher +242: boxer +243: bull mastiff +244: Tibetan mastiff +245: French bulldog +246: Great Dane +247: Saint Bernard +248: Eskimo dog +249: malamute +250: Siberian husky +251: dalmatian +252: affenpinscher +253: basenji +254: pug +255: Leonberg +256: Newfoundland +257: Great Pyrenees +258: Samoyed +259: Pomeranian +260: chow +261: keeshond +262: Brabancon griffon +263: Pembroke +264: Cardigan +265: toy poodle +266: miniature poodle +267: standard poodle +268: Mexican hairless +269: timber wolf +270: white wolf +271: red wolf +272: coyote +273: dingo +274: dhole +275: African hunting dog +276: hyena +277: red fox +278: kit fox +279: Arctic fox +280: grey fox +281: tabby +282: tiger cat +283: Persian cat +284: Siamese cat +285: Egyptian cat +286: cougar +287: lynx +288: leopard +289: snow leopard +290: jaguar +291: lion +292: tiger +293: cheetah +294: brown bear +295: American black bear +296: ice bear +297: sloth bear +298: mongoose +299: meerkat +300: tiger beetle +301: ladybug +302: ground beetle +303: long-horned beetle +304: leaf beetle +305: dung beetle +306: rhinoceros beetle +307: weevil +308: fly +309: bee +310: ant +311: grasshopper +312: cricket +313: walking stick +314: cockroach +315: mantis +316: cicada +317: leafhopper +318: lacewing +319: dragonfly +320: damselfly +321: admiral +322: ringlet +323: monarch +324: cabbage butterfly +325: sulphur butterfly +326: lycaenid +327: starfish +328: sea urchin +329: sea cucumber +330: wood rabbit +331: hare +332: Angora +333: hamster +334: porcupine +335: fox squirrel +336: marmot +337: beaver +338: guinea pig +339: sorrel +340: zebra +341: hog +342: wild boar +343: warthog +344: hippopotamus +345: ox +346: water buffalo +347: bison +348: ram +349: bighorn +350: ibex +351: hartebeest +352: impala +353: gazelle +354: Arabian camel +355: llama +356: weasel +357: mink +358: polecat +359: black-footed ferret +360: otter +361: skunk +362: badger +363: armadillo +364: three-toed sloth +365: orangutan +366: gorilla +367: chimpanzee +368: gibbon +369: siamang +370: guenon +371: patas +372: baboon +373: macaque +374: langur +375: colobus +376: proboscis monkey +377: marmoset +378: capuchin +379: howler monkey +380: titi +381: spider monkey +382: squirrel monkey +383: Madagascar cat +384: indri +385: Indian elephant +386: African elephant +387: lesser panda +388: giant panda +389: barracouta +390: eel +391: coho +392: rock beauty +393: anemone fish +394: sturgeon +395: gar +396: lionfish +397: puffer +398: abacus +399: abaya +400: academic gown +401: accordion +402: acoustic guitar +403: aircraft carrier +404: airliner +405: airship +406: altar +407: ambulance +408: amphibian +409: analog clock +410: apiary +411: apron +412: ashcan +413: assault rifle +414: backpack +415: bakery +416: balance beam +417: balloon +418: ballpoint +419: Band Aid +420: banjo +421: bannister +422: barbell +423: barber chair +424: barbershop +425: barn +426: barometer +427: barrel +428: barrow +429: baseball +430: basketball +431: bassinet +432: bassoon +433: bathing cap +434: bath towel +435: bathtub +436: beach wagon +437: beacon +438: beaker +439: bearskin +440: beer bottle +441: beer glass +442: bell cote +443: bib +444: bicycle-built-for-two +445: bikini +446: binder +447: binoculars +448: birdhouse +449: boathouse +450: bobsled +451: bolo tie +452: bonnet +453: bookcase +454: bookshop +455: bottlecap +456: bow +457: bow tie +458: brass +459: brassiere +460: breakwater +461: breastplate +462: broom +463: bucket +464: buckle +465: bulletproof vest +466: bullet train +467: butcher shop +468: cab +469: caldron +470: candle +471: cannon +472: canoe +473: can opener +474: cardigan +475: car mirror +476: carousel +477: carpenters kit +478: carton +479: car wheel +480: cash machine +481: cassette +482: cassette player +483: castle +484: catamaran +485: CD player +486: cello +487: cellular telephone +488: chain +489: chainlink fence +490: chain mail +491: chain saw +492: chest +493: chiffonier +494: chime +495: china cabinet +496: Christmas stocking +497: church +498: cinema +499: cleaver +500: cliff dwelling +501: cloak +502: clog +503: cocktail shaker +504: coffee mug +505: coffeepot +506: coil +507: combination lock +508: computer keyboard +509: confectionery +510: container ship +511: convertible +512: corkscrew +513: cornet +514: cowboy boot +515: cowboy hat +516: cradle +517: crane +518: crash helmet +519: crate +520: crib +521: Crock Pot +522: croquet ball +523: crutch +524: cuirass +525: dam +526: desk +527: desktop computer +528: dial telephone +529: diaper +530: digital clock +531: digital watch +532: dining table +533: dishrag +534: dishwasher +535: disk brake +536: dock +537: dogsled +538: dome +539: doormat +540: drilling platform +541: drum +542: drumstick +543: dumbbell +544: Dutch oven +545: electric fan +546: electric guitar +547: electric locomotive +548: entertainment center +549: envelope +550: espresso maker +551: face powder +552: feather boa +553: file +554: fireboat +555: fire engine +556: fire screen +557: flagpole +558: flute +559: folding chair +560: football helmet +561: forklift +562: fountain +563: fountain pen +564: four-poster +565: freight car +566: French horn +567: frying pan +568: fur coat +569: garbage truck +570: gasmask +571: gas pump +572: goblet +573: go-kart +574: golf ball +575: golfcart +576: gondola +577: gong +578: gown +579: grand piano +580: greenhouse +581: grille +582: grocery store +583: guillotine +584: hair slide +585: hair spray +586: half track +587: hammer +588: hamper +589: hand blower +590: hand-held computer +591: handkerchief +592: hard disc +593: harmonica +594: harp +595: harvester +596: hatchet +597: holster +598: home theater +599: honeycomb +600: hook +601: hoopskirt +602: horizontal bar +603: horse cart +604: hourglass +605: iPod +606: iron +607: jack-o-lantern +608: jean +609: jeep +610: jersey +611: jigsaw puzzle +612: jinrikisha +613: joystick +614: kimono +615: knee pad +616: knot +617: lab coat +618: ladle +619: lampshade +620: laptop +621: lawn mower +622: lens cap +623: letter opener +624: library +625: lifeboat +626: lighter +627: limousine +628: liner +629: lipstick +630: Loafer +631: lotion +632: loudspeaker +633: loupe +634: lumbermill +635: magnetic compass +636: mailbag +637: mailbox +638: maillot +639: maillot +640: manhole cover +641: maraca +642: marimba +643: mask +644: matchstick +645: maypole +646: maze +647: measuring cup +648: medicine chest +649: megalith +650: microphone +651: microwave +652: military uniform +653: milk can +654: minibus +655: miniskirt +656: minivan +657: missile +658: mitten +659: mixing bowl +660: mobile home +661: Model T +662: modem +663: monastery +664: monitor +665: moped +666: mortar +667: mortarboard +668: mosque +669: mosquito net +670: motor scooter +671: mountain bike +672: mountain tent +673: mouse +674: mousetrap +675: moving van +676: muzzle +677: nail +678: neck brace +679: necklace +680: nipple +681: notebook +682: obelisk +683: oboe +684: ocarina +685: odometer +686: oil filter +687: organ +688: oscilloscope +689: overskirt +690: oxcart +691: oxygen mask +692: packet +693: paddle +694: paddlewheel +695: padlock +696: paintbrush +697: pajama +698: palace +699: panpipe +700: paper towel +701: parachute +702: parallel bars +703: park bench +704: parking meter +705: passenger car +706: patio +707: pay-phone +708: pedestal +709: pencil box +710: pencil sharpener +711: perfume +712: Petri dish +713: photocopier +714: pick +715: pickelhaube +716: picket fence +717: pickup +718: pier +719: piggy bank +720: pill bottle +721: pillow +722: ping-pong ball +723: pinwheel +724: pirate +725: pitcher +726: plane +727: planetarium +728: plastic bag +729: plate rack +730: plow +731: plunger +732: Polaroid camera +733: pole +734: police van +735: poncho +736: pool table +737: pop bottle +738: pot +739: potters wheel +740: power drill +741: prayer rug +742: printer +743: prison +744: projectile +745: projector +746: puck +747: punching bag +748: purse +749: quill +750: quilt +751: racer +752: racket +753: radiator +754: radio +755: radio telescope +756: rain barrel +757: recreational vehicle +758: reel +759: reflex camera +760: refrigerator +761: remote control +762: restaurant +763: revolver +764: rifle +765: rocking chair +766: rotisserie +767: rubber eraser +768: rugby ball +769: rule +770: running shoe +771: safe +772: safety pin +773: saltshaker +774: sandal +775: sarong +776: sax +777: scabbard +778: scale +779: school bus +780: schooner +781: scoreboard +782: screen +783: screw +784: screwdriver +785: seat belt +786: sewing machine +787: shield +788: shoe shop +789: shoji +790: shopping basket +791: shopping cart +792: shovel +793: shower cap +794: shower curtain +795: ski +796: ski mask +797: sleeping bag +798: slide rule +799: sliding door +800: slot +801: snorkel +802: snowmobile +803: snowplow +804: soap dispenser +805: soccer ball +806: sock +807: solar dish +808: sombrero +809: soup bowl +810: space bar +811: space heater +812: space shuttle +813: spatula +814: speedboat +815: spider web +816: spindle +817: sports car +818: spotlight +819: stage +820: steam locomotive +821: steel arch bridge +822: steel drum +823: stethoscope +824: stole +825: stone wall +826: stopwatch +827: stove +828: strainer +829: streetcar +830: stretcher +831: studio couch +832: stupa +833: submarine +834: suit +835: sundial +836: sunglass +837: sunglasses +838: sunscreen +839: suspension bridge +840: swab +841: sweatshirt +842: swimming trunks +843: swing +844: switch +845: syringe +846: table lamp +847: tank +848: tape player +849: teapot +850: teddy +851: television +852: tennis ball +853: thatch +854: theater curtain +855: thimble +856: thresher +857: throne +858: tile roof +859: toaster +860: tobacco shop +861: toilet seat +862: torch +863: totem pole +864: tow truck +865: toyshop +866: tractor +867: trailer truck +868: tray +869: trench coat +870: tricycle +871: trimaran +872: tripod +873: triumphal arch +874: trolleybus +875: trombone +876: tub +877: turnstile +878: typewriter keyboard +879: umbrella +880: unicycle +881: upright +882: vacuum +883: vase +884: vault +885: velvet +886: vending machine +887: vestment +888: viaduct +889: violin +890: volleyball +891: waffle iron +892: wall clock +893: wallet +894: wardrobe +895: warplane +896: washbasin +897: washer +898: water bottle +899: water jug +900: water tower +901: whiskey jug +902: whistle +903: wig +904: window screen +905: window shade +906: Windsor tie +907: wine bottle +908: wing +909: wok +910: wooden spoon +911: wool +912: worm fence +913: wreck +914: yawl +915: yurt +916: web site +917: comic book +918: crossword puzzle +919: street sign +920: traffic light +921: book jacket +922: menu +923: plate +924: guacamole +925: consomme +926: hot pot +927: trifle +928: ice cream +929: ice lolly +930: French loaf +931: bagel +932: pretzel +933: cheeseburger +934: hotdog +935: mashed potato +936: head cabbage +937: broccoli +938: cauliflower +939: zucchini +940: spaghetti squash +941: acorn squash +942: butternut squash +943: cucumber +944: artichoke +945: bell pepper +946: cardoon +947: mushroom +948: Granny Smith +949: strawberry +950: orange +951: lemon +952: fig +953: pineapple +954: banana +955: jackfruit +956: custard apple +957: pomegranate +958: hay +959: carbonara +960: chocolate sauce +961: dough +962: meat loaf +963: pizza +964: potpie +965: burrito +966: red wine +967: espresso +968: cup +969: eggnog +970: alp +971: bubble +972: cliff +973: coral reef +974: geyser +975: lakeside +976: promontory +977: sandbar +978: seashore +979: valley +980: volcano +981: ballplayer +982: groom +983: scuba diver +984: rapeseed +985: daisy +986: yellow ladys slipper +987: corn +988: acorn +989: hip +990: buckeye +991: coral fungus +992: agaric +993: gyromitra +994: stinkhorn +995: earthstar +996: hen-of-the-woods +997: bolete +998: ear +999: toilet tissue \ No newline at end of file diff --git a/MMaDA/training/optimizer.py b/MMaDA/training/optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..a499b90a2827c96aadc4adaccf712c67210829df --- /dev/null +++ b/MMaDA/training/optimizer.py @@ -0,0 +1,81 @@ +# Copyright 2023 Google Research. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""PyTorch implementation of the Lion optimizer.""" +import torch +from torch.optim.optimizer import Optimizer + + +class Lion(Optimizer): + r"""Implements Lion algorithm.""" + + def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0.0, **kwargs): + """Initialize the hyperparameters. + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-4) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.99)) + weight_decay (float, optional): weight decay coefficient (default: 0) + """ + + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay) + super().__init__(params, defaults) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + Returns: + the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: + continue + + # Perform stepweight decay + p.data.mul_(1 - group["lr"] * group["weight_decay"]) + + grad = p.grad + state = self.state[p] + # State initialization + if len(state) == 0: + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like(p) + + exp_avg = state["exp_avg"] + beta1, beta2 = group["betas"] + + # Weight update + update = exp_avg * beta1 + grad * (1 - beta1) + p.add_(torch.sign(update), alpha=-group["lr"]) + # Decay the momentum running average coefficient + exp_avg.mul_(beta2).add_(grad, alpha=1 - beta2) + + return loss diff --git a/MMaDA/training/prompting_utils.py b/MMaDA/training/prompting_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8c83d2540ad3fa1c23d78cc6dd51529fb5ef8048 --- /dev/null +++ b/MMaDA/training/prompting_utils.py @@ -0,0 +1,1906 @@ +# coding=utf-8 +# Copyright 2025 MMaDA team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +reserved_token_mapping = { + '<|soi|>': 126084, + '<|eoi|>': 126085, + '<|sov|>': 126086, + '<|eov|>': 126087, + '<|t2i|>': 126088, + '<|mmu|>': 126089, + '<|t2v|>': 126090, + '<|v2v|>': 126091, + '<|lvg|>': 126092, + '[iPAD]': 126093, + '<|r2i|>': 126094, + '<|i2i|>': 126095, + '<|s2t|>': 126096, + '<|soa|>': 126097, + '<|eoa|>': 126098, + '<|t2s|>': 126099, + '<|v2t|>': 126100, + '<|s2s|>': 126101, # yejoon: add s2s + '<|v2s|>': 126102, # csjihwanh: added v2s +} + + +import torch +from typing import Union, Optional +class UniversalPrompting(): + def __init__(self, text_tokenizer, + special_tokens=("<|soi|>", "<|eoi|>", "<|sov|>", "<|eov|>", "<|t2i|>", "<|mmu|>", "<|t2v|>", "<|v2v|>", "<|lvg|>", "<|i2i|>","<|s2t|>","<|soa|>","<|eoa|>","<|t2s|>", "<|v2t|>"), + max_text_len=8000, max_audio_len = 384, max_audio_len_short = 256, max_seq_len=377, max_image_len=512, ignore_id=-100, cond_dropout_prob=0.1, use_reserved_token=False): + """ + :param text_tokenizer: original text tokenizer + """ + if not use_reserved_token: + self.text_tokenizer = text_tokenizer + self.text_tokenizer.add_special_tokens({'pad_token': '[PAD]'}) + self.text_tokenizer.add_tokens(list(special_tokens)) + self.sptids_dict = {token: torch.tensor(self.text_tokenizer.convert_tokens_to_ids([token])) for token in + special_tokens} + self.sptids_dict['<|sot|>'] = torch.tensor([self.text_tokenizer.bos_token_id]) + self.sptids_dict['<|eot|>'] = torch.tensor([self.text_tokenizer.eos_token_id]) + self.sptids_dict['<|pad|>'] = torch.tensor([self.text_tokenizer.pad_token_id]) + else: + self.text_tokenizer = text_tokenizer + self.sptids_dict = {} + for token, token_id in reserved_token_mapping.items(): + self.sptids_dict[token] = torch.tensor([token_id]) + self.sptids_dict['<|sot|>'] = torch.tensor([self.text_tokenizer.bos_token_id]) + self.sptids_dict['<|eot|>'] = torch.tensor([self.text_tokenizer.eos_token_id]) + end_header_tokens = self.text_tokenizer.convert_tokens_to_ids(['<|end_header_id|>']) + if end_header_tokens and len(end_header_tokens) > 0 and end_header_tokens[0]: + self.sptids_dict['<|end_header_id|>'] = torch.tensor(end_header_tokens) + self.sptids_dict['<|eot_id|>'] = torch.tensor(self.text_tokenizer.convert_tokens_to_ids(['<|eot_id|>'])) + self.sptids_dict['<|start_header_id|>'] = torch.tensor(self.text_tokenizer.convert_tokens_to_ids(['<|start_header_id|>'])) + else: + special_tokens_dict = { + 'additional_special_tokens': [ + '<|start_header_id|>', + '<|end_header_id|>', + '<|eot_id|>' + ] + } + num_added = self.text_tokenizer.add_special_tokens(special_tokens_dict) + new_token_id = self.text_tokenizer.convert_tokens_to_ids(['<|end_header_id|>']) + self.sptids_dict['<|end_header_id|>'] = torch.tensor(new_token_id) + self.sptids_dict['<|eot_id|>'] = torch.tensor(self.text_tokenizer.convert_tokens_to_ids(['<|eot_id|>'])) + self.sptids_dict['<|start_header_id|>'] = torch.tensor(self.text_tokenizer.convert_tokens_to_ids(['<|start_header_id|>'])) + # plus 1 because at this time we add a task token before + print(f"self.sptids_dict: {self.sptids_dict}") + self.max_text_len = max_text_len + 1 + self.max_image_len = max_image_len + self.max_audio_len = max_audio_len + self.max_audio_len_short = max_audio_len_short + self.pad_id = reserved_token_mapping['[iPAD]'] + self.ignore_id = ignore_id + self.cond_dropout_prob = cond_dropout_prob + + def t2i_prompt(self, text_ids, image_ids, labels): + + device = image_ids.device + sequence_ids = [] + attention_masks = [] + label_ids = [] + probs = torch.rand(len(text_ids)) + eos_id = self.text_tokenizer.eos_token_id + + for i in range(len(text_ids)): + + if len(text_ids[i]) == 0: + text_ids[i] = [self.text_tokenizer.bos_token_id] + elif text_ids[i][0] != self.text_tokenizer.bos_token_id: + text_ids[i] = [self.text_tokenizer.bos_token_id] + text_ids[i] + + temp_ids = [int(self.sptids_dict['<|t2i|>'])] + text_ids[i] + [self.text_tokenizer.eos_token_id] + + # randomly dropout text condition + if probs[i] < self.cond_dropout_prob: + temp_ids = [int(self.sptids_dict['<|t2i|>']), self.text_tokenizer.bos_token_id, self.text_tokenizer.eos_token_id] + + if self.max_text_len >= len(temp_ids): + old_len = len(temp_ids) + temp_ids = [self.pad_id] * (self.max_text_len - len(temp_ids)) + temp_ids + temp_masks = [0] * (self.max_text_len - old_len) + [1] * (old_len + image_ids.shape[-1] + 2) + else: + # should add the eos token + temp_ids = temp_ids[:self.max_text_len - 1] + [self.text_tokenizer.eos_token_id] + temp_masks = [1] * (len(temp_ids) + image_ids.shape[-1] + 2) # +2 for two special tokens + # prompting -- [task token] [sot] [text tokens] [eot] [soi] [image tokens] [eoi] + temp_label_ids = torch.cat([ + # should we predict text tokens when doing image reconstruction? + torch.tensor(temp_ids).to(device), + self.sptids_dict['<|soi|>'].to(device), + labels[i], + self.sptids_dict['<|eoi|>'].to(device) + ], dim=0) + + + temp_ids = torch.cat([ + torch.tensor(temp_ids).to(device), + self.sptids_dict['<|soi|>'].to(device), + image_ids[i], + self.sptids_dict['<|eoi|>'].to(device) + ], dim=0) + + # sequence_ids: [pad]...[pad] <|t2i|> text_1 ... text_n <|soi|> image_1 ... image_m <|eoi|> + temp_masks = torch.tensor(temp_masks).to(device) + sequence_ids.append(temp_ids.unsqueeze(0)) + attention_masks.append(temp_masks.unsqueeze(0)) + label_ids.append(temp_label_ids.unsqueeze(0)) + + return torch.cat(sequence_ids, dim=0), torch.cat(attention_masks, dim=0), torch.cat(label_ids, dim=0) + + def i2i_prompt(self, prompts, original_image_tokens, masked_edited_tokens, labels): + """ + Constructs the input sequence for the Image-to-Image task. + Final sequence structure: + <|i2i|> <|soi|> [source_image] <|eoi|> text_1 ... text_n <|soi|> [masked_target_image] <|eoi|> + """ + device = original_image_tokens.device + batch_size = len(prompts) + + sequence_ids = [] + attention_masks = [] + label_ids = [] + + tokenized_prompts = self.text_tokenizer(prompts, add_special_tokens=False, return_tensors=None).input_ids + + for i in range(batch_size): + + # 1. Process text prompts with and + # --------------------------------------------------- + temp_text_ids = [self.text_tokenizer.bos_token_id] + tokenized_prompts[i] + [self.text_tokenizer.eos_token_id] + + if torch.rand(1) < self.cond_dropout_prob: + temp_text_ids = [self.text_tokenizer.bos_token_id, self.text_tokenizer.eos_token_id] + + if self.max_text_len >= len(temp_text_ids): + pad_len = self.max_text_len - len(temp_text_ids) + padded_text_ids = [self.pad_id] * pad_len + temp_text_ids + else: + padded_text_ids = temp_text_ids[:self.max_text_len-1] + [self.text_tokenizer.eos_token_id] + + padded_text_ids_tensor = torch.tensor(padded_text_ids, device=device) + + # 2. Construct the full input sequence (input_ids) + # --------------------------------------------------- + temp_ids = torch.cat([ + self.sptids_dict['<|i2i|>'].to(device), + self.sptids_dict['<|soi|>'].to(device), + original_image_tokens[i], + self.sptids_dict['<|eoi|>'].to(device), + padded_text_ids_tensor, + self.sptids_dict['<|soi|>'].to(device), # Using <|soi|> for target image + masked_edited_tokens[i], + self.sptids_dict['<|eoi|>'].to(device) # Using <|eoi|> for target image + ], dim=0) + + sequence_ids.append(temp_ids.unsqueeze(0)) + + # 3. Construct the labels + # --------------------------------------------------- + len_prefix_ignore = ( + 1 + # <|i2i|> + 1 + len(original_image_tokens[i]) + 1 + # <|soi|>...<|eoi|> + len(padded_text_ids_tensor) + # text + 1 # <|soi|> for target + ) + + ignore_labels = torch.full((len_prefix_ignore,), self.ignore_id, device=device) + + temp_label_ids = torch.cat([ + ignore_labels, + labels[i], + torch.tensor([self.ignore_id], device=device) # Ignore final <|eoi|> + ], dim=0) + + label_ids.append(temp_label_ids.unsqueeze(0)) + + # 4. Construct the attention mask + # --------------------------------------------------- + text_attention_mask = (padded_text_ids_tensor != self.pad_id).long() + + # All non-padding tokens should have attention + len_prefix = 1 + 1 + len(original_image_tokens[i]) + 1 + len_suffix = 1 + len(masked_edited_tokens[i]) + 1 + + prefix_mask = torch.ones(len_prefix, device=device) + suffix_mask = torch.ones(len_suffix, device=device) + + temp_masks = torch.cat([ + prefix_mask, + text_attention_mask, + suffix_mask + ], dim=0) + + attention_masks.append(temp_masks.unsqueeze(0)) + + return torch.cat(sequence_ids, dim=0), torch.cat(attention_masks, dim=0), torch.cat(label_ids, dim=0) + + def t2i_gen_prompt(self, text_ids, image_ids): + + device = image_ids.device + sequence_ids = [] + attention_masks = [] + for i in range(len(text_ids)): + if len(text_ids[i]) == 0: + text_ids[i] = [self.text_tokenizer.bos_token_id] + elif text_ids[i][0] != self.text_tokenizer.bos_token_id: + text_ids[i] = [self.text_tokenizer.bos_token_id] + text_ids[i] + # note that, llama3 tokenizer automatically add the bot token at first but without eot + temp_ids = [int(self.sptids_dict['<|t2i|>'])] + text_ids[i] + [self.text_tokenizer.eos_token_id] + if self.max_text_len >= len(temp_ids): + old_len = len(temp_ids) + temp_ids = [self.pad_id] * (self.max_text_len - len(temp_ids)) + temp_ids + temp_masks = [0] * (self.max_text_len - old_len) + [1] * (old_len + image_ids.shape[-1] + 2) + else: + # should add the eos token + temp_ids = temp_ids[:self.max_text_len - 1] + [self.text_tokenizer.eos_token_id] + temp_masks = [1] * (len(temp_ids) + image_ids.shape[-1] + 2) # +2 for two special tokens + + # prompting -- [task token] [sot] [text tokens] [eot] [soi] [image tokens] [eoi] + temp_ids = torch.cat([ + torch.tensor(temp_ids).to(device), + self.sptids_dict['<|soi|>'].to(device), + image_ids[i], + self.sptids_dict['<|eoi|>'].to(device) + ], dim=0) + + temp_masks = torch.tensor(temp_masks).to(device) + sequence_ids.append(temp_ids.unsqueeze(0)) + attention_masks.append(temp_masks.unsqueeze(0)) + + return torch.cat(sequence_ids, dim=0), torch.cat(attention_masks, dim=0) + + def i2i_gen_prompt(self, texts, input_image_tokens, output_image_placeholder): + device = input_image_tokens.device + + if isinstance(texts, str): + texts = [texts] + + batch_size = len(texts) + + sequence_ids_batch = [] + attention_masks_batch = [] + + for i in range(batch_size): + text_item = texts[i] + input_img_item = input_image_tokens[i] + output_img_placeholder_item = output_image_placeholder[i] + + text_ids_list = self.text_tokenizer(text_item)['input_ids'] + + if not text_ids_list: + text_ids_list = [self.text_tokenizer.bos_token_id] + elif text_ids_list[0] != self.text_tokenizer.bos_token_id: + text_ids_list = [self.text_tokenizer.bos_token_id] + text_ids_list + text_ids_list.append(self.text_tokenizer.eos_token_id) + + max_text_len = self.max_text_len + if max_text_len >= len(text_ids_list): + pad_len = max_text_len - len(text_ids_list) + padded_text_ids = [self.pad_id] * pad_len + text_ids_list + text_attention_mask_list = [0] * pad_len + [1] * len(text_ids_list) + else: + padded_text_ids = text_ids_list[:max_text_len - 1] + [self.text_tokenizer.eos_token_id] + text_attention_mask_list = [1] * max_text_len + + # [TASK][CONDITION_IMG][CONDITION_TEXT][START_GEN][TARGET_IMG][END_GEN] + temp_ids = torch.cat([ + self.sptids_dict['<|t2i|>'].to(device), + self.sptids_dict['<|soi|>'].to(device), + input_img_item, + self.sptids_dict['<|eoi|>'].to(device), + self.sptids_dict['<|sot|>'].to(device), + torch.tensor(padded_text_ids, dtype=torch.long, device=device), + self.sptids_dict['<|eot|>'].to(device), + self.sptids_dict['<|soi|>'].to(device), + output_img_placeholder_item, + self.sptids_dict['<|eoi|>'].to(device) + ], dim=0) + + temp_masks = torch.cat([ + torch.ones(1, dtype=torch.long, device=device), + torch.ones(1, dtype=torch.long, device=device), + torch.ones_like(input_img_item, dtype=torch.long), + torch.ones(1, dtype=torch.long, device=device), + torch.ones(1, dtype=torch.long, device=device), + torch.tensor(text_attention_mask_list, dtype=torch.long, device=device), + torch.ones(1, dtype=torch.long, device=device), + torch.ones(1, dtype=torch.long, device=device), + torch.ones_like(output_img_placeholder_item, dtype=torch.long), + torch.ones(1, dtype=torch.long, device=device) + ], dim=0) + + sequence_ids_batch.append(temp_ids.unsqueeze(0)) + attention_masks_batch.append(temp_masks.unsqueeze(0)) + + return torch.cat(sequence_ids_batch, dim=0), torch.cat(attention_masks_batch, dim=0) + + def t2s_gen_prompt(self, text_ids, audio_ids): + + device = audio_ids.device + sequence_ids = [] + attention_masks = [] + for i in range(len(text_ids)): + if len(text_ids[i]) == 0: + text_ids[i] = [self.text_tokenizer.bos_token_id] + elif text_ids[i][0] != self.text_tokenizer.bos_token_id: + text_ids[i] = [self.text_tokenizer.bos_token_id] + text_ids[i] + # note that, llama3 tokenizer automatically add the bot token at first but without eot + temp_ids = [int(self.sptids_dict['<|t2s|>'])] + text_ids[i] + [self.text_tokenizer.eos_token_id] + if self.max_text_len >= len(temp_ids): + old_len = len(temp_ids) + temp_ids = [self.pad_id] * (self.max_text_len - len(temp_ids)) + temp_ids + temp_masks = [0] * (self.max_text_len - old_len) + [1] * (old_len + audio_ids.shape[-1] + 1) + else: + # should add the eos token + temp_ids = temp_ids[:self.max_text_len - 1] + [self.text_tokenizer.eos_token_id] + temp_masks = [1] * (len(temp_ids) + audio_ids.shape[-1] + 1) # +1 for SOA + + # prompting -- [task token] [sot] [text tokens] [eot] [soi] [audio tokens] [eoi] + temp_ids = torch.cat([ + torch.tensor(temp_ids).to(device), + self.sptids_dict['<|soa|>'].to(device), + audio_ids[i], + # self.sptids_dict['<|eoa|>'].to(device) + ], dim=0) + + temp_masks = torch.tensor(temp_masks).to(device) + sequence_ids.append(temp_ids.unsqueeze(0)) + attention_masks.append(temp_masks.unsqueeze(0)) + + return torch.cat(sequence_ids, dim=0), torch.cat(attention_masks, dim=0) + + def t2s_fixed_gen_prompt(self, text_ids, audio_ids): + + device = audio_ids.device + sequence_ids = [] + attention_masks = [] + for i in range(len(text_ids)): + if len(text_ids[i]) == 0: + text_ids[i] = [self.text_tokenizer.bos_token_id] + elif text_ids[i][0] != self.text_tokenizer.bos_token_id: + text_ids[i] = [self.text_tokenizer.bos_token_id] + text_ids[i] + # note that, llama3 tokenizer automatically add the bot token at first but without eot + temp_ids = [int(self.sptids_dict['<|t2s|>'])] + text_ids[i] + [self.text_tokenizer.eos_token_id] + if self.max_text_len >= len(temp_ids): + old_len = len(temp_ids) + temp_ids = [self.pad_id] * (self.max_text_len - len(temp_ids)) + temp_ids + temp_masks = [0] * (self.max_text_len - old_len) + [1] * (old_len + audio_ids.shape[-1] + 2) + else: + # should add the eos token + temp_ids = temp_ids[:self.max_text_len - 1] + [self.text_tokenizer.eos_token_id] + temp_masks = [1] * (len(temp_ids) + audio_ids.shape[-1] + 2) # +1 for SOA and EOA + + # prompting -- [task token] [sot] [text tokens] [eot] [soi] [audio tokens] [eoi] + temp_ids = torch.cat([ + torch.tensor(temp_ids).to(device), + self.sptids_dict['<|soa|>'].to(device), + audio_ids[i], + self.sptids_dict['<|eoa|>'].to(device) + ], dim=0) + + temp_masks = torch.tensor(temp_masks).to(device) + sequence_ids.append(temp_ids.unsqueeze(0)) + attention_masks.append(temp_masks.unsqueeze(0)) + + return torch.cat(sequence_ids, dim=0), torch.cat(attention_masks, dim=0) + + # language modeling + def lm_prompt(self, text_ids, max_seq_len): + sequence_ids = [] + attention_masks = [] + label_ids = [] + for i in range(len(text_ids)): + if len(text_ids[i]) == 0: + text_ids[i] = [self.text_tokenizer.bos_token_id] + elif text_ids[i][0] != self.text_tokenizer.bos_token_id: + text_ids[i] = [self.text_tokenizer.bos_token_id] + text_ids[i] + + temp_ids = text_ids[i] + [self.text_tokenizer.eos_token_id] + + if max_seq_len >= len(temp_ids): + temp_labels_ids = temp_ids + [self.text_tokenizer.eos_token_id] * (max_seq_len - len(temp_ids)) + temp_ids = temp_ids + [self.text_tokenizer.eos_token_id] * (max_seq_len - len(temp_ids)) + temp_masks = [1] * len(temp_ids) + [0] * (max_seq_len - len(temp_ids)) + else: + # In language modeling, we only process text tokens. We do not add the eos token if the text length + # exceeds the max sequence length + temp_labels_ids = temp_ids[:max_seq_len] + temp_ids = temp_ids[:max_seq_len] + temp_masks = [1] * len(temp_ids) # +2 for two special tokens + + # prompting -- [task token] [sot] [text tokens] [eot] [soi] [image tokens] [eoi] + temp_ids = torch.tensor(temp_ids) + temp_masks = torch.tensor(temp_masks) + temp_labels_ids = torch.tensor(temp_labels_ids) + sequence_ids.append(temp_ids.unsqueeze(0)) + attention_masks.append(temp_masks.unsqueeze(0)) + label_ids.append(temp_labels_ids.unsqueeze(0)) + + # input_ids, masks, labels + return torch.cat(sequence_ids, dim=0), torch.cat(attention_masks, dim=0), torch.cat(label_ids, dim=0) + + # language modeling + def lm_chat_prompt(self, text_ids, max_seq_len): + sequence_ids = [] + prompt_masks = [] + label_ids = [] + + for i in range(len(text_ids)): + if len(text_ids[i]) == 0: + text_ids[i] = [self.text_tokenizer.bos_token_id] + elif text_ids[i][0] != self.text_tokenizer.bos_token_id: + text_ids[i] = [self.text_tokenizer.bos_token_id] + text_ids[i] + + temp_ids = text_ids[i] + [self.text_tokenizer.eos_token_id] + + if max_seq_len >= len(temp_ids): + temp_labels_ids = temp_ids + [self.text_tokenizer.eos_token_id] * (max_seq_len - len(temp_ids)) + temp_ids = temp_ids + [self.text_tokenizer.eos_token_id] * (max_seq_len - len(temp_ids)) + else: + # In language modeling, we only process text tokens. We do not add the eos token if the text length + # exceeds the max sequence length + temp_labels_ids = temp_ids[:max_seq_len] + temp_ids = temp_ids[:max_seq_len] + + end_header_id = int(self.sptids_dict['<|end_header_id|>']) + end_header_pos = -1 + for pos in range(len(temp_ids) - 1, -1, -1): # å°čÆ•ä»Žę–‡ęœ¬åŗåˆ—äø­åÆ»ę‰¾<|end_header_id|> + if temp_ids[pos] == end_header_id: + end_header_pos = pos + break + if end_header_pos != -1: + prompt_length = end_header_pos + 1 + else: + prompt_length = 0 + temp_masks = [1] * prompt_length + [0] * (len(temp_ids) - prompt_length) + + # prompting -- [task token] [sot] [text tokens] [eot] [soi] [image tokens] [eoi] + temp_ids = torch.tensor(temp_ids) + temp_masks = torch.tensor(temp_masks) + temp_labels_ids = torch.tensor(temp_labels_ids) + sequence_ids.append(temp_ids.unsqueeze(0)) + prompt_masks.append(temp_masks.unsqueeze(0)) + label_ids.append(temp_labels_ids.unsqueeze(0)) + + # input_ids, masks, labels + return torch.cat(sequence_ids, dim=0), torch.cat(prompt_masks, dim=0), torch.cat(label_ids, dim=0) + + def s2s_gen_prompt( + self, + audio_usr_ids: list[torch.Tensor], + audio_asst_placeholders: list[torch.Tensor], + image_ids: Optional[list[Optional[torch.Tensor]]] = None, + ): + if len(audio_usr_ids) != len(audio_asst_placeholders): + raise ValueError("audio_usr_ids and audio_asst_placeholders must have the same length") + if image_ids is None: + image_ids = [None] * len(audio_usr_ids) + elif len(image_ids) != len(audio_usr_ids): + raise ValueError("image_ids length must match user audio list") + + device = audio_usr_ids[0].device + + task_token = self.sptids_dict['<|s2s|>'].to(device).view(-1) + soa_token = self.sptids_dict['<|soa|>'].to(device).view(-1) + eoa_token = self.sptids_dict['<|eoa|>'].to(device).view(-1) + soi_token = self.sptids_dict['<|soi|>'].to(device).view(-1) + eoi_token = self.sptids_dict['<|eoi|>'].to(device).view(-1) + + user_header = "<|start_header_id|>user<|end_header_id|>\n" + asst_header = "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n" + u_tokens = self.text_tokenizer(user_header, return_tensors="pt").input_ids.to(device).view(-1) + a_tokens = self.text_tokenizer(asst_header, return_tensors="pt").input_ids.to(device).view(-1) + + sequences: list[torch.Tensor] = [] + attention_masks: list[torch.Tensor] = [] + + for usr_tokens, asst_placeholder, img_tokens in zip(audio_usr_ids, audio_asst_placeholders, image_ids): + usr_tokens = usr_tokens.to(device).long() + if usr_tokens.dim() > 1: + usr_tokens = usr_tokens.view(-1) + + asst_placeholder = asst_placeholder.to(device).long() + if asst_placeholder.dim() > 1: + asst_placeholder = asst_placeholder.view(-1) + + seq_parts = [task_token, u_tokens] + + if isinstance(img_tokens, (list, tuple)): + for seg in img_tokens: + if seg is None: + continue + seg = seg.to(device).long() + if seg.dim() > 1: + seg = seg.view(-1) + seq_parts.extend([soi_token, seg, eoi_token]) + elif img_tokens is not None: + seg = img_tokens.to(device).long() + if seg.dim() > 1: + seg = seg.view(-1) + seq_parts.extend([soi_token, seg, eoi_token]) + + seq_parts.extend([soa_token, usr_tokens, eoa_token, a_tokens, soa_token, asst_placeholder]) + + seq = torch.cat(seq_parts, dim=0) + attn_mask = torch.ones_like(seq, dtype=torch.long) + + sequences.append(seq.unsqueeze(0)) + attention_masks.append(attn_mask.unsqueeze(0)) + + return torch.cat(sequences, dim=0), torch.cat(attention_masks, dim=0) + + def v2s_gen_prompt( + self, + video_ids: Union[list[torch.Tensor], torch.Tensor], + text_ids: list[list[int]], + audio_placeholders: list[torch.Tensor], + ): + if len(text_ids) != len(audio_placeholders): + raise ValueError("text_ids and audio_placeholders must have the same length") + + if isinstance(video_ids, torch.Tensor): + video_list = [video_ids[i] for i in range(video_ids.shape[0])] + else: + video_list = video_ids + + device = audio_placeholders[0].device + + v2s_token = self.sptids_dict['<|v2s|>'].to(device).view(-1) + soi_token = self.sptids_dict['<|soi|>'].to(device).view(-1) + eoi_token = self.sptids_dict['<|eoi|>'].to(device).view(-1) + soa_token = self.sptids_dict['<|soa|>'].to(device).view(-1) + eoa_token = self.sptids_dict['<|eoa|>'].to(device).view(-1) + + max_text_len = self.max_text_len - 1 + max_audio_len = self.max_audio_len_short + eos_id = self.text_tokenizer.eos_token_id + + sequences: list[torch.Tensor] = [] + attention_masks: list[torch.Tensor] = [] + + for vid_tokens, txt_ids, audio_placeholder in zip(video_list, text_ids, audio_placeholders): + if len(txt_ids) == 0: + txt_ids = [self.text_tokenizer.bos_token_id] + elif txt_ids[0] != self.text_tokenizer.bos_token_id: + txt_ids = [self.text_tokenizer.bos_token_id] + txt_ids + + temp_text = txt_ids + [eos_id] + if len(temp_text) < max_text_len: + temp_text = temp_text + [eos_id] * (max_text_len - len(temp_text)) + else: + temp_text = temp_text[:max_text_len - 1] + [eos_id] + text_tensor = torch.tensor(temp_text, dtype=torch.long, device=device) + + if isinstance(vid_tokens, torch.Tensor): + vid_tensor = vid_tokens.to(device).long() + else: + vid_tensor = torch.tensor(vid_tokens, dtype=torch.long, device=device) + if vid_tensor.dim() > 1: + vid_tensor = vid_tensor.view(-1) + + audio_placeholder = audio_placeholder.to(device).long() + if audio_placeholder.dim() > 1: + audio_placeholder = audio_placeholder.view(-1) + + audio_block = torch.cat([soa_token, audio_placeholder], dim=0) + if audio_block.numel() > max_audio_len: + audio_block = audio_block[:max_audio_len] + elif audio_block.numel() < max_audio_len: + pad_len = max_audio_len - audio_block.numel() + if pad_len > 0: + pad_value = audio_placeholder.new_full((pad_len,), audio_placeholder[-1].item()) + audio_block = torch.cat([audio_block, pad_value], dim=0) + + seq = torch.cat([v2s_token, soi_token, vid_tensor, eoi_token, text_tensor, audio_block], dim=0) + attn_mask = torch.ones_like(seq, dtype=torch.long) + + sequences.append(seq.unsqueeze(0)) + attention_masks.append(attn_mask.unsqueeze(0)) + + return torch.cat(sequences, dim=0), torch.cat(attention_masks, dim=0) + + def mmu_prompt(self, image_ids, text_ids): + device = image_ids.device + sequence_ids = [] + prompt_masks = [] + label_ids = [] + max_text_len = self.max_text_len - 1 + eos_id = self.text_tokenizer.eos_token_id + eos_id = self.text_tokenizer.eos_token_id + for i in range(len(text_ids)): + # note that, llama3 tokenizer automatically add the bot token at first but without eot + # for empty list [] + + if len(text_ids[i]) == 0: + text_ids[i] = [self.text_tokenizer.bos_token_id] + elif text_ids[i][0] != self.text_tokenizer.bos_token_id: + text_ids[i] = [self.text_tokenizer.bos_token_id] + text_ids[i] + + temp_ids = text_ids[i] + [self.text_tokenizer.eos_token_id] + + if max_text_len >= len(temp_ids): + # minus 1 because task token was prepended to the former image tokens + temp_ids = temp_ids + [eos_id] * (max_text_len - len(temp_ids)) + temp_masks = [1] * (len(temp_ids) + image_ids.shape[-1] + 3) + [0] * (max_text_len - len(temp_ids)) + else: + # should add the eos token + temp_ids = temp_ids[:max_text_len - 1] + [self.text_tokenizer.eos_token_id] + temp_masks = [1] * (len(temp_ids) + image_ids.shape[-1] + 3) # +2 for two special tokens + + # prompting -- [task token] [sot] [text tokens] [eot] [soi] [image tokens] [eoi] + temp_label_ids = torch.cat([ + torch.tensor([self.ignore_id]).to(device), + torch.tensor([self.ignore_id]).to(device), + torch.ones_like(image_ids[i]) * self.ignore_id, + torch.tensor([self.ignore_id]).to(device), + torch.tensor(temp_ids).to(device), + ], dim=0) + + + return_temp_ids = torch.cat([ + self.sptids_dict['<|mmu|>'].to(device), # task token + self.sptids_dict['<|soi|>'].to(device), + image_ids[i], + self.sptids_dict['<|eoi|>'].to(device), + torch.tensor(temp_ids).to(device), + ], dim=0) + end_header_id = int(self.sptids_dict['<|end_header_id|>']) + end_header_pos = -1 + for pos in range(len(temp_ids) - 1, -1, -1): + if temp_ids[pos] == end_header_id: + end_header_pos = pos + break + if end_header_pos != -1: + prompt_length = len(return_temp_ids) - len(temp_ids) + end_header_pos + 1 + else: + prompt_length = len(return_temp_ids) - len(temp_ids) + predict_length = len(return_temp_ids) - prompt_length + prompt_mask = [1] * prompt_length + [0] * predict_length + prompt_mask = torch.tensor(prompt_mask).to(device) + sequence_ids.append(return_temp_ids.unsqueeze(0)) + prompt_masks.append(prompt_mask.unsqueeze(0)) + label_ids.append(temp_label_ids.unsqueeze(0)) + + return torch.cat(sequence_ids, dim=0), torch.cat(prompt_masks, dim=0), torch.cat(label_ids, dim=0) + + def mmu_mult_prompt(self, batch_image_ids_list, batch_text_ids): + """ + Multi-image prompt builder (strict whole-image fit). + + INPUTS + ------- + batch_image_ids_list : List[List[torch.LongTensor]] + Length = B. + For sample i, a list of K_i encoded image token tensors, each shape (L_img_k,). + Example for one sample: [img_ids_0, img_ids_1, ...] + IMPORTANT: Images are already *encoded* to discrete token IDs. + + batch_text_ids : List[List[int]] + Length = B. + For sample i, raw tokenized text IDs (no BOS/EOS added here). + IMPORTANT: Text is *tokenized*, not encoded beyond text tokenizer IDs. + + RETURNS + ------- + sequence_ids : torch.LongTensor, shape (B, 1 + max_image_len + max_text_len) + The model input IDs: + [] + [ image_0 ... image_m ] + [text_ids (BOS...EOS padded)] + prompt_masks : torch.LongTensor, shape (B, 1 + max_image_len + max_text_len) + 1 = prompt/context, 0 = generation region + (split determined by last <|end_header_id|> inside the text segment; if none, entire text is generation) + label_ids : torch.LongTensor, shape (B, 1 + max_image_len + max_text_len) + self.ignore_id for specials & image tokens; text positions = text token IDs + """ + + B = len(batch_text_ids) + device = ( + batch_image_ids_list[0][0].device + if (batch_image_ids_list and batch_image_ids_list[0]) + else torch.device("cpu") + ) + + max_text_len = self.max_text_len - 1 + max_image_len = self.max_image_len + + # Text tokenizer ids + bos_id = self.text_tokenizer.bos_token_id + eos_id = self.text_tokenizer.eos_token_id + + # Specials (stored in self.sptids_dict as 1D LongTensors) + mmu_tok = self.sptids_dict['<|mmu|>'].to(device) # task token, shape (1,) + soi_tok = self.sptids_dict['<|soi|>'].to(device) # start of image + eoi_tok = self.sptids_dict['<|eoi|>'].to(device) # end of image + end_header_id = int(self.sptids_dict['<|end_header_id|>']) + + sequence_rows = [] + mask_rows = [] + label_rows = [] + + for i in range(B): + # ---------------- 1) Build image block under the image budget ---------------- + img_parts = [] + used_img_len = 0 + + for img_ids in batch_image_ids_list[i]: + img_ids = img_ids.to(device) + need = 2 + img_ids.numel() # + tokens + + if used_img_len + need <= max_image_len: + img_parts.extend([soi_tok, img_ids, eoi_tok]) + used_img_len += need + else: + # skip whole image if it doesn't fit (no partial slicing) + continue + + image_block = ( + torch.cat( + [p if isinstance(p, torch.Tensor) else torch.tensor([p], device=device) for p in img_parts], + dim=0 + ) + if img_parts else torch.empty((0,), dtype=torch.long, device=device) + ) + + # ---------------- 2) Prepare text to fill the remaining budget ---------------- + # Target per-sample total length: + # 1 (mmu) + max_image_len + max_text_len + # Prefix currently uses: 1 (mmu) + used_img_len + # So text must fill: text_budget = max_text_len + (max_image_len - used_img_len) + text_budget = max_text_len + (max_image_len - used_img_len) + + text_ids = batch_text_ids[i] + # Ensure BOS at start + if len(text_ids) == 0: + text_ids = [bos_id] + elif text_ids[0] != bos_id: + text_ids = [bos_id] + text_ids + + # Append EOS and pad/truncate to text_budget + tmp = text_ids + [eos_id] + if len(tmp) < text_budget: + tmp = tmp + [eos_id] * (text_budget - len(tmp)) + else: + tmp = tmp[:max(text_budget - 1, 0)] + ([eos_id] if text_budget > 0 else []) + + temp_ids = torch.tensor(tmp, dtype=torch.long, device=device) # text segment + + # ---------------- 3) Build final sequence ---------------- + prefix = torch.cat([mmu_tok, image_block], dim=0) # length = 1 + used_img_len + return_ids = torch.cat([prefix, temp_ids], dim=0) # length = 1 + used_img_len + text_budget + + # Enforce exact total length + expected_len = 1 + max_image_len + max_text_len + assert return_ids.numel() == expected_len, f"got {return_ids.numel()}, want {expected_len}" + + # ---------------- 4) Labels (ignore for specials/images; supervise text) ---------------- + ignore_prefix = torch.full((prefix.numel(),), self.ignore_id, dtype=torch.long, device=device) + temp_label_ids = torch.cat([ignore_prefix, temp_ids], dim=0) # same length as return_ids + + # ---------------- 5) Prompt mask ---------------- + # Find last <|end_header_id|> in the TEXT region; generation starts after it. + end_header_pos = -1 + for pos in range(temp_ids.numel() - 1, -1, -1): + if temp_ids[pos].item() == end_header_id: + end_header_pos = pos + break + + if end_header_pos != -1: + prompt_len = prefix.numel() + (end_header_pos + 1) + else: + # Match original behavior: if not found, prompt is only the prefix (images+mmu), + # and the entire text region is the generation target. + prompt_len = prefix.numel() + + predict_len = return_ids.numel() - prompt_len + prompt_mask = torch.cat([ + torch.ones(prompt_len, dtype=torch.long, device=device), + torch.zeros(max(predict_len, 0), dtype=torch.long, device=device) + ], dim=0) + + # ---------------- 6) Collect rows (keep original return structure) ---------------- + sequence_rows.append(return_ids.unsqueeze(0)) + mask_rows.append(prompt_mask.unsqueeze(0)) + label_rows.append(temp_label_ids.unsqueeze(0)) + + # cat along dim=0 + return torch.cat(sequence_rows, dim=0), torch.cat(mask_rows, dim=0), torch.cat(label_rows, dim=0) + + def v2t_prompt(self, image_ids, text_ids): + device = image_ids.device + sequence_ids = [] + prompt_masks = [] + label_ids = [] + max_text_len = self.max_text_len - 1 + eos_id = self.text_tokenizer.eos_token_id + for i in range(len(text_ids)): + # note that, llama3 tokenizer automatically add the bot token at first but without eot + # for empty list [] + + if len(text_ids[i]) == 0: + text_ids[i] = [self.text_tokenizer.bos_token_id] + elif text_ids[i][0] != self.text_tokenizer.bos_token_id: + text_ids[i] = [self.text_tokenizer.bos_token_id] + text_ids[i] + + temp_ids = text_ids[i] + [self.text_tokenizer.eos_token_id] + + if max_text_len >= len(temp_ids): + # minus 1 because task token was prepended to the former image tokens + temp_ids = temp_ids + [eos_id] * (max_text_len - len(temp_ids)) + temp_masks = [1] * (len(temp_ids) + image_ids.shape[-1] + 3) + [0] * (max_text_len - len(temp_ids)) + else: + # should add the eos token + temp_ids = temp_ids[:max_text_len - 1] + [self.text_tokenizer.eos_token_id] + temp_masks = [1] * (len(temp_ids) + image_ids.shape[-1] + 3) # +2 for two special tokens + + # prompting -- [task token] [sot] [text tokens] [eot] [soi] [image tokens] [eoi] + temp_label_ids = torch.cat([ + torch.tensor([self.ignore_id]).to(device), + torch.tensor([self.ignore_id]).to(device), + torch.ones_like(image_ids[i]) * self.ignore_id, + torch.tensor([self.ignore_id]).to(device), + torch.tensor(temp_ids).to(device), + ], dim=0) + + + return_temp_ids = torch.cat([ + self.sptids_dict['<|v2t|>'].to(device), # task token + self.sptids_dict['<|soi|>'].to(device), + image_ids[i], + self.sptids_dict['<|eoi|>'].to(device), + torch.tensor(temp_ids).to(device), + ], dim=0) + end_header_id = int(self.sptids_dict['<|end_header_id|>']) + end_header_pos = -1 + for pos in range(len(temp_ids) - 1, -1, -1): + if temp_ids[pos] == end_header_id: + end_header_pos = pos + break + if end_header_pos != -1: + prompt_length = len(return_temp_ids) - len(temp_ids) + end_header_pos + 1 + else: + prompt_length = len(return_temp_ids) - len(temp_ids) + predict_length = len(return_temp_ids) - prompt_length + prompt_mask = [1] * prompt_length + [0] * predict_length + prompt_mask = torch.tensor(prompt_mask).to(device) + sequence_ids.append(return_temp_ids.unsqueeze(0)) + prompt_masks.append(prompt_mask.unsqueeze(0)) + label_ids.append(temp_label_ids.unsqueeze(0)) + + return torch.cat(sequence_ids, dim=0), torch.cat(prompt_masks, dim=0), torch.cat(label_ids, dim=0) + + def _v2s_prompt_impl( + self, + image_ids, + text_ids, + audio_ids, + supervise_padding: bool = True, + ): + """ + image_ids: list[torch.Tensor] or Tensor[B, L_img] + text_ids : list[list[int]] + audio_ids: list[torch.Tensor] # each shaped (1, L_audio) + """ + device = (image_ids[0].device if isinstance(image_ids, list) else image_ids.device) + sequence_ids, prompt_masks, label_ids = [], [], [] + + max_text_len = self.max_text_len - 1 + max_audio_len = self.max_audio_len_short + eos_id = self.text_tokenizer.eos_token_id + ignore_id = self.ignore_id + + B = len(text_ids) + for i in range(B): + # ---- Text normalize ---- + if len(text_ids[i]) == 0: + text_ids[i] = [self.text_tokenizer.bos_token_id] + elif text_ids[i][0] != self.text_tokenizer.bos_token_id: + text_ids[i] = [self.text_tokenizer.bos_token_id] + text_ids[i] + + temp_text_ids = text_ids[i] + [eos_id] + if len(temp_text_ids) < max_text_len: + temp_text_ids = temp_text_ids + [eos_id] * (max_text_len - len(temp_text_ids)) + else: + temp_text_ids = temp_text_ids[:max_text_len - 1] + [eos_id] + text_tensor = torch.tensor(temp_text_ids, dtype=torch.long, device=device) + + # ---- Audio block with /, clamp/pad to max_audio_len ---- + soa = self.sptids_dict['<|soa|>'].to(device).unsqueeze(0) # (1,1) + eoa = self.sptids_dict['<|eoa|>'].to(device).unsqueeze(0) # (1,1) + audio_block = torch.cat([soa, audio_ids[i], eoa], dim=1) # (1, L+2) + pre_pad_len = audio_block.shape[1] + actual_len = min(pre_pad_len, max_audio_len) + + audio_block = audio_block[:, :actual_len] + if pre_pad_len > max_audio_len and audio_block[0, -1] != eoa[0]: + audio_block[0, -1] = eoa[0] + + pad_len = max_audio_len - audio_block.shape[1] + if pad_len > 0: + pad = torch.full((1, pad_len), eos_id, dtype=torch.long, device=device) + audio_block = torch.cat([audio_block, pad], dim=1) + + # ---- Sequence: <|v2s|>, <|soi|>, image, <|eoi|>, text, <|soa|>, audio..., <|eoa|>, [pads...] ---- + v2s = self.sptids_dict['<|v2s|>'].to(device).to(torch.long) + soi = self.sptids_dict['<|soi|>'].to(device).to(torch.long) + eoi = self.sptids_dict['<|eoi|>'].to(device).to(torch.long) + + img_tokens = image_ids[i] if isinstance(image_ids, list) else image_ids[i] + img_tokens = img_tokens.to(device).to(torch.long) + + seq = torch.cat([ + v2s, soi, img_tokens, eoi, text_tensor, audio_block.squeeze(0).to(torch.long) + ], dim=0) + + # ---- Prompt mask: 1 through and including , then 0 over audio targets ---- + prompt_length = 1 + 1 + img_tokens.shape[-1] + 1 + len(text_tensor) + 1 # +1 for + total_length = seq.shape[0] + predict_length = actual_len - 1 # exclude + padding_region = total_length - prompt_length - predict_length + + tail_mask_value = 0 if supervise_padding else 1 + prompt_mask = torch.tensor( + [1]*prompt_length + + [0]*predict_length + + [tail_mask_value]*padding_region, + dtype=torch.long, + device=device, + ) + + # ---- Labels: ignore prompt, then audio after ---- + audio_targets = audio_block.squeeze(0)[1:actual_len].to(torch.long) + if supervise_padding: + padding_labels = audio_block.squeeze(0)[actual_len:].to(torch.long) + else: + padding_labels = torch.full((padding_region,), ignore_id, dtype=torch.long, device=device) + + label = torch.cat([ + torch.full((prompt_length,), ignore_id, dtype=torch.long, device=device), + audio_targets, + padding_labels, + ], dim=0) + + # ---- Sanity checks ---- + assert audio_block.shape[1] == max_audio_len + assert total_length == prompt_length + (max_audio_len - 1) # targets exclude + assert label.shape[0] == total_length + assert prompt_mask.shape[0] == total_length + + sequence_ids.append(seq.unsqueeze(0)) + prompt_masks.append(prompt_mask.unsqueeze(0)) + label_ids.append(label.unsqueeze(0)) + + return ( + torch.cat(sequence_ids, dim=0), + torch.cat(prompt_masks, dim=0), + torch.cat(label_ids, dim=0), + ) + + def v2s_prompt(self, image_ids, text_ids, audio_ids): + return self._v2s_prompt_impl(image_ids, text_ids, audio_ids, supervise_padding=True) + + def v2s_prompt_ignore_padding(self, image_ids, text_ids, audio_ids): + return self._v2s_prompt_impl(image_ids, text_ids, audio_ids, supervise_padding=False) + + + # def s2t_prompt(self, audio_ids, text_ids): + # device = audio_ids.device + # sequence_ids = [] + # prompt_masks = [] + # label_ids = [] + # max_text_len = self.max_text_len - 1 + # for i in range(len(text_ids)): + # # note that, llama3 tokenizer automatically add the bot token at first but without eot + # # for empty list [] + + # if len(text_ids[i]) == 0: + # text_ids[i] = [self.text_tokenizer.bos_token_id] + # elif text_ids[i][0] != self.text_tokenizer.bos_token_id: + # text_ids[i] = [self.text_tokenizer.bos_token_id] + text_ids[i] + + # temp_ids = text_ids[i] + [self.text_tokenizer.eos_token_id] + + # if max_text_len >= len(temp_ids): + # # minus 1 because task token was prepended to the former audio tokens + # temp_ids = temp_ids + [self.pad_id] * (max_text_len - len(temp_ids)) + # temp_masks = [1] * (len(temp_ids) + audio_ids.shape[-1] + 3) + [0] * (max_text_len - len(temp_ids)) # NOTE: left or right padding? ė‘˜ 중에 ķ•˜ė‚˜ė§Œ ź³Øė¼ + # else: + # # should add the eos token + # temp_ids = temp_ids[:max_text_len - 1] + [self.text_tokenizer.eos_token_id] + # tempmasks = [1] * (len(temp_ids) + audio_ids.shape[-1] + 3) # +2 for two special tokens + + # # prompting -- [task token] [sot] [text tokens] [eot] [soa] [audio tokens] [eoa] + # temp_label_ids = torch.cat([ + # torch.tensor([self.ignore_id]).to(device), + # torch.tensor([self.ignore_id]).to(device), + # torch.ones_like(audio_ids[i]) * self.ignore_id, + # torch.tensor([self.ignore_id]).to(device), + # torch.tensor(temp_ids).to(device), + # ], dim=0) + + + # return_temp_ids = torch.cat([ + # self.sptids_dict['<|s2t|>'].to(device), # task token + # self.sptids_dict['<|soa|>'].to(device), + # audio_ids[i], + # self.sptids_dict['<|eoa|>'].to(device), + # torch.tensor(temp_ids).to(device), + # ], dim=0) + # end_header_id = int(self.sptids_dict['<|end_header_id|>']) + # end_header_pos = -1 + # for pos in range(len(temp_ids) - 1, -1, -1): + # if temp_ids[pos] == end_header_id: + # end_header_pos = pos + # break + # if end_header_pos != -1: + # prompt_length = len(return_temp_ids) - len(temp_ids) + end_header_pos + 1 + # else: + # prompt_length = len(return_temp_ids) - len(temp_ids) + # predict_length = len(return_temp_ids) - prompt_length + # prompt_mask = [1] * prompt_length + [0] * predict_length + # prompt_mask = torch.tensor(prompt_mask).to(device) + # sequence_ids.append(return_temp_ids.unsqueeze(0)) + # prompt_masks.append(prompt_mask.unsqueeze(0)) + # label_ids.append(temp_label_ids.unsqueeze(0)) + + # return torch.cat(sequence_ids, dim=0), torch.cat(prompt_masks, dim=0), torch.cat(label_ids, dim=0) + + # def t2s_prompt(self, text_ids, audio_ids): + + # device = audio_ids.device + # sequence_ids = [] + # prompt_masks = [] + # label_ids = [] + # probs = torch.rand(len(text_ids)) + + # for i in range(len(text_ids)): + # # [text tokens] -> [task token] [bos] [text tokens] [eos] + # if len(text_ids[i]) == 0: + # text_ids[i] = [self.text_tokenizer.bos_token_id] + # elif text_ids[i][0] != self.text_tokenizer.bos_token_id: + # text_ids[i] = [self.text_tokenizer.bos_token_id] + text_ids[i] + # temp_ids = [int(self.sptids_dict['<|t2s|>'])] + text_ids[i] + [self.text_tokenizer.eos_token_id] + + # # randomly dropout text condition + # if probs[i] < self.cond_dropout_prob: + # temp_ids = [int(self.sptids_dict['<|t2s|>']), self.text_tokenizer.bos_token_id, self.text_tokenizer.eos_token_id] + + # # padding + # if self.max_text_len >= len(temp_ids): + # old_len = len(temp_ids) + # temp_ids = [self.pad_id] * (self.max_text_len - len(temp_ids)) + temp_ids + # temp_masks = [0] * (self.max_text_len - old_len) + [1] * (old_len + audio_ids.shape[-1] + 2) + # else: + # # should add the eos token + # temp_ids = temp_ids[:self.max_text_len - 1] + [self.text_tokenizer.eos_token_id] + # temp_masks = [1] * (len(temp_ids) + audio_ids.shape[-1] + 2) # +2 for two special tokens + + # text_part_len = len(temp_ids) + + # # prompting -- [pad] [task token] [bos] [text tokens] [eos] [soa] [audio tokens] [eoa] + # seq = torch.cat([ + # torch.tensor(temp_ids).to(device), + # self.sptids_dict['<|soa|>'].to(device), + # audio_ids[i], + # self.sptids_dict['<|eoa|>'].to(device) + # ], dim=0) + + # prompt_part_len = text_part_len + 1 # + 1 for <|soa|> + # audio_part_len = audio_ids[i].shape[-1] + 1 # +1 for <|eoa|> + # prompt_mask = [1] * prompt_part_len + [0] * audio_part_len + # prompt_mask = torch.tensor(prompt_mask).to(device) + + # label = torch.cat([ + # torch.full((prompt_part_len,), self.ignore_id, device=device), + # audio_ids[i], + # torch.tensor([self.ignore_id], device=device) + # ], dim=0) + + # # sequence_ids: [pad]...[pad] <|t2s|> <|bos|> text_1 ... text_n <|eos|> <|soa|> audio_1 ... audio_m <|eoa|> + # sequence_ids.append(seq.unsqueeze(0)) + # prompt_masks.append(prompt_mask.unsqueeze(0)) + # label_ids.append(label.unsqueeze(0)) + + # return torch.cat(sequence_ids, dim=0), torch.cat(prompt_masks, dim=0), torch.cat(label_ids, dim=0) + + def s2t_prompt(self, audio_ids, text_ids): + + device = audio_ids[0].device + + sequence_ids, prompt_masks, label_ids = [], [], [] + max_text_len = self.max_text_len - 1 + max_audio_len = self.max_audio_len + 1 + eos_id = self.text_tokenizer.eos_token_id + + for i in range(len(text_ids)): + task_tensor = self.sptids_dict['<|s2t|>'].to(device).unsqueeze(0) + soa_tensor = self.sptids_dict['<|soa|>'].to(device).unsqueeze(0) + eoa_tensor = self.sptids_dict['<|eoa|>'].to(device).unsqueeze(0) + current_audio_tokens = audio_ids[i] + + # (<|s2t|>, <|soa|>, <|eoa|>) + effective_max_audio = max_audio_len - 3 + if current_audio_tokens.shape[1] > effective_max_audio: + current_audio_tokens = current_audio_tokens[:, :effective_max_audio] + + audio_block = torch.cat([task_tensor, soa_tensor, current_audio_tokens, eoa_tensor], dim=1) + + num_padding = max_audio_len - audio_block.shape[1] + if num_padding > 0: + padding_tensor = torch.full((1, num_padding), self.pad_id, dtype=torch.long, device=device) + padded_audio_block = torch.cat([padding_tensor, audio_block], dim=1) + else: + padded_audio_block = audio_block + + padded_audio_block_len = padded_audio_block.shape[1] + + if len(text_ids[i]) == 0: + text_ids[i] = [self.text_tokenizer.bos_token_id] + elif text_ids[i][0] != self.text_tokenizer.bos_token_id: + text_ids[i] = [self.text_tokenizer.bos_token_id] + text_ids[i] + + temp_ids = text_ids[i] + [self.text_tokenizer.eos_token_id] + + if max_text_len >= len(temp_ids): + temp_ids = temp_ids + [eos_id] * (max_text_len - len(temp_ids)) + else: + temp_ids = temp_ids[:max_text_len - 1] + [self.text_tokenizer.eos_token_id] + + return_temp_ids = torch.cat([ + padded_audio_block.squeeze(0), + torch.tensor(temp_ids, device=device), + ], dim=0) + + + prompt_length = padded_audio_block_len + temp_label_ids = torch.cat([ + torch.full((prompt_length,), self.ignore_id, device=device), + torch.tensor(temp_ids, device=device), + ], dim=0) + + end_header_id = int(self.sptids_dict['<|end_header_id|>']) + end_header_pos = -1 + for pos in range(len(temp_ids) - 1, -1, -1): + if temp_ids[pos] == end_header_id: + end_header_pos = pos; break + + if end_header_pos != -1: + final_prompt_length = prompt_length + end_header_pos + 1 + else: + final_prompt_length = prompt_length + + prompt_mask = torch.tensor([1] * final_prompt_length + [0] * (len(return_temp_ids) - final_prompt_length), device=device) + + sequence_ids.append(return_temp_ids.unsqueeze(0)) + prompt_masks.append(prompt_mask.unsqueeze(0)) + label_ids.append(temp_label_ids.unsqueeze(0)) + + return torch.cat(sequence_ids, dim=0), torch.cat(prompt_masks, dim=0), torch.cat(label_ids, dim=0) + + def t2s_prompt(self, text_ids, audio_ids): + """ + text_ids: list[list[int]] + audio_ids: list[torch.Tensor] + """ + + device = audio_ids[0].device + + audio_pad_token = self.text_tokenizer.eos_token_id + max_text_len = self.max_text_len + max_audio_len = self.max_audio_len + + sequence_ids, prompt_masks, label_ids = [], [], [] + probs = torch.rand(len(text_ids)) + + for i in range(len(text_ids)): + if len(text_ids[i]) == 0: + text_ids[i] = [self.text_tokenizer.bos_token_id] + elif text_ids[i][0] != self.text_tokenizer.bos_token_id: + text_ids[i] = [self.text_tokenizer.bos_token_id] + text_ids[i] + temp_ids = [int(self.sptids_dict['<|t2s|>'])] + text_ids[i] + [self.text_tokenizer.eos_token_id] + + if probs[i] < self.cond_dropout_prob: + temp_ids = [int(self.sptids_dict['<|t2s|>']), self.text_tokenizer.bos_token_id, self.text_tokenizer.eos_token_id] + + if max_text_len >= len(temp_ids): + temp_ids = [self.pad_id] * (max_text_len- len(temp_ids)) + temp_ids + else: + temp_ids = temp_ids[:max_text_len - 1] + [self.text_tokenizer.eos_token_id] + text_part_len = len(temp_ids) + + soa_tensor = self.sptids_dict['<|soa|>'].to(device).unsqueeze(0) + eoa_tensor = self.sptids_dict['<|eoa|>'].to(device).unsqueeze(0) + audio_block = torch.cat([soa_tensor, audio_ids[i], eoa_tensor], dim=1) + + if audio_block.shape[1] > max_audio_len: + audio_block = audio_block[:, :max_audio_len] + if audio_block[0, -1] != eoa_tensor[0]: + audio_block[0, -1] = eoa_tensor[0] + + num_padding = max_audio_len - audio_block.shape[1] + if num_padding > 0: + padding_tensor = torch.full((1, num_padding), audio_pad_token, dtype=torch.long, device=device) + padded_audio_ids = torch.cat([audio_block, padding_tensor], dim=1) + else: + padded_audio_ids = audio_block + + seq = torch.cat([ + torch.tensor(temp_ids, device=device), + padded_audio_ids.squeeze(0) + ], dim=0) + + prompt_part_len = text_part_len + 1 # add 1 for + audio_part_len = max_audio_len - 1 # subsitude 1 for + prompt_mask = torch.tensor([1] * prompt_part_len + [0] * audio_part_len, device=device) + + label = torch.cat([ + torch.full((prompt_part_len,), self.ignore_id, device=device), # ignore up to + padded_audio_ids.squeeze(0)[1:] # delete token + ], dim=0) + + sequence_ids.append(seq.unsqueeze(0)) + prompt_masks.append(prompt_mask.unsqueeze(0)) + label_ids.append(label.unsqueeze(0)) + + return torch.cat(sequence_ids, dim=0), torch.cat(prompt_masks, dim=0), torch.cat(label_ids, dim=0) + + + def t2s_prompt_ignore_padding(self, text_ids, audio_ids): + + device = audio_ids[0].device + + audio_pad_token = self.text_tokenizer.eos_token_id + max_text_len = self.max_text_len + max_audio_len = self.max_audio_len + + sequence_ids, prompt_masks, label_ids = [], [], [] + probs = torch.rand(len(text_ids)) + + for i in range(len(text_ids)): + if len(text_ids[i]) == 0: + text_ids[i] = [self.text_tokenizer.bos_token_id] + elif text_ids[i][0] != self.text_tokenizer.bos_token_id: + text_ids[i] = [self.text_tokenizer.bos_token_id] + text_ids[i] + temp_ids = [int(self.sptids_dict['<|t2s|>'])] + text_ids[i] + [self.text_tokenizer.eos_token_id] + + if probs[i] < self.cond_dropout_prob: + temp_ids = [int(self.sptids_dict['<|t2s|>']), self.text_tokenizer.bos_token_id, self.text_tokenizer.eos_token_id] + + if max_text_len >= len(temp_ids): + temp_ids = [self.pad_id] * (max_text_len- len(temp_ids)) + temp_ids + else: + temp_ids = temp_ids[:max_text_len - 1] + [self.text_tokenizer.eos_token_id] + text_part_len = len(temp_ids) + + soa_tensor = self.sptids_dict['<|soa|>'].to(device).unsqueeze(0) + eoa_tensor = self.sptids_dict['<|eoa|>'].to(device).unsqueeze(0) + audio_block = torch.cat([soa_tensor, audio_ids[i], eoa_tensor], dim=1) + + if audio_block.shape[1] > max_audio_len: + audio_block = audio_block[:, :max_audio_len] + if audio_block[0, -1] != eoa_tensor[0]: + audio_block[0, -1] = eoa_tensor[0] + + num_padding = max_audio_len - audio_block.shape[1] + if num_padding > 0: + padding_tensor = torch.full((1, num_padding), audio_pad_token, dtype=torch.long, device=device) + padded_audio_ids = torch.cat([audio_block, padding_tensor], dim=1) + else: + padded_audio_ids = audio_block + + # Full input sequence + seq = torch.cat([ + torch.tensor(temp_ids, device=device), + padded_audio_ids.squeeze(0) + ], dim=0) + + # Compute lengths for masking/labels + prompt_part_len = text_part_len + 1 # include + actual_audio_target_len = audio_block.shape[1] - 1 # exclude , include audio VQ codes and + padded_region_len = num_padding # trailing padding tokens after audio_block + + # Prompt mask: do NOT mask text, , or trailing padding; mask only real audio targets + prompt_mask = torch.tensor( + [1] * prompt_part_len + [0] * actual_audio_target_len + [1] * padded_region_len, + device=device, + ) + + # Labels: ignore prompt and padding regions; supervise only real audio targets + label = torch.cat([ + torch.full((prompt_part_len,), self.ignore_id, device=device), + audio_block.squeeze(0)[1:], # drop + torch.full((padded_region_len,), self.ignore_id, device=device), + ], dim=0) + + sequence_ids.append(seq.unsqueeze(0)) + prompt_masks.append(prompt_mask.unsqueeze(0)) + label_ids.append(label.unsqueeze(0)) + + return torch.cat(sequence_ids, dim=0), torch.cat(prompt_masks, dim=0), torch.cat(label_ids, dim=0) + + def _s2s_prompt_impl( + self, + audio_usr_ids: list[torch.Tensor], + audio_asst_ids: list[torch.Tensor], + image_ids: Optional[list[Optional[torch.Tensor]]] = None, + supervise_padding: bool = False, + ): + if len(audio_usr_ids) != len(audio_asst_ids): + raise ValueError("audio_usr_ids and audio_asst_ids must have the same length") + + if image_ids is None: + image_ids = [None] * len(audio_usr_ids) + elif len(image_ids) != len(audio_usr_ids): + raise ValueError("image_ids length must match audio lists") + + if len(audio_usr_ids) == 0: + raise ValueError("s2s_prompt requires at least one sample") + + device = audio_usr_ids[0].device + + task_tensor = self.sptids_dict['<|s2s|>'].to(device).unsqueeze(0) + soa_tensor = self.sptids_dict['<|soa|>'].to(device).unsqueeze(0) + eoa_tensor = self.sptids_dict['<|eoa|>'].to(device).unsqueeze(0) + soi_tensor = self.sptids_dict['<|soi|>'].to(device).unsqueeze(0) + eoi_tensor = self.sptids_dict['<|eoi|>'].to(device).unsqueeze(0) + + user_header = "<|start_header_id|>user<|end_header_id|>\n" + asst_header = "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n" + u_tokens = self.text_tokenizer(user_header, return_tensors="pt").input_ids.to(device) + a_tokens = self.text_tokenizer(asst_header, return_tensors="pt").input_ids.to(device) + + left_pad_id = self.pad_id + right_pad_id = self.text_tokenizer.eos_token_id + max_audio_len = self.max_audio_len_short + + sequence_ids = [] + prompt_masks = [] + label_ids = [] + + task_len = task_tensor.shape[1] + soa_len = soa_tensor.shape[1] + eoa_len = eoa_tensor.shape[1] + user_header_len = u_tokens.shape[1] + asst_header_len = a_tokens.shape[1] + + for usr_tokens, asst_tokens, img_tokens in zip(audio_usr_ids, audio_asst_ids, image_ids): + if usr_tokens.device != device: + usr_tokens = usr_tokens.to(device) + if asst_tokens.device != device: + asst_tokens = asst_tokens.to(device) + if usr_tokens.dim() == 1: + usr_tokens = usr_tokens.unsqueeze(0) + if asst_tokens.dim() == 1: + asst_tokens = asst_tokens.unsqueeze(0) + usr_tokens = usr_tokens.long() + asst_tokens = asst_tokens.long() + + image_block = None + image_block_len = 0 + if isinstance(img_tokens, (list, tuple)): + image_segments = [] + for segment in img_tokens: + if segment is None: + continue + segment = segment.to(device) + if segment.dim() == 1: + segment = segment.unsqueeze(0) + segment = segment.long() + seg_block = torch.cat([soi_tensor, segment, eoi_tensor], dim=1) + image_segments.append(seg_block) + if image_segments: + image_block = torch.cat(image_segments, dim=1) + image_block_len = image_block.shape[1] + elif img_tokens is not None: + img_tokens = img_tokens.to(device) + if img_tokens.dim() == 1: + img_tokens = img_tokens.unsqueeze(0) + img_tokens = img_tokens.long() + image_block = torch.cat([soi_tensor, img_tokens, eoi_tensor], dim=1) + image_block_len = image_block.shape[1] + + max_usr_audio = max_audio_len - (task_len + user_header_len + soa_len + eoa_len) + max_usr_audio = max(0, max_usr_audio) + if usr_tokens.shape[1] > max_usr_audio: + usr_tokens = usr_tokens[:, :max_usr_audio] + + usr_parts = [task_tensor, u_tokens] + if image_block is not None: + usr_parts.append(image_block) + usr_parts.extend([soa_tensor, usr_tokens, eoa_tensor]) + usr_block = torch.cat(usr_parts, dim=1) + + target_usr_len = max_audio_len + image_block_len + num_usr_pad = target_usr_len - usr_block.shape[1] + if num_usr_pad < 0: + num_usr_pad = 0 + if num_usr_pad > 0: + usr_block = torch.cat([ + torch.full((1, num_usr_pad), left_pad_id, dtype=torch.long, device=device), + usr_block + ], dim=1) + + max_asst_audio = max_audio_len - (asst_header_len + soa_len + eoa_len) + max_asst_audio = max(0, max_asst_audio) + if asst_tokens.shape[1] > max_asst_audio: + asst_tokens = asst_tokens[:, :max_asst_audio] + + asst_block = torch.cat([a_tokens, soa_tensor, asst_tokens, eoa_tensor], dim=1) + target_asst_len = max_audio_len + num_asst_pad = target_asst_len - asst_block.shape[1] + if num_asst_pad < 0: + num_asst_pad = 0 + if num_asst_pad > 0: + asst_block = torch.cat([ + asst_block, + torch.full((1, num_asst_pad), right_pad_id, dtype=torch.long, device=device) + ], dim=1) + + seq = torch.cat([usr_block, asst_block], dim=1) + + prefix_len = usr_block.shape[1] + asst_header_len + soa_len + target_len = asst_tokens.shape[1] + eoa_len + padding_len = asst_block.shape[1] - (asst_header_len + soa_len + target_len) + + mask_segments = [ + torch.ones((prefix_len,), device=device, dtype=torch.long), + torch.zeros((target_len,), device=device, dtype=torch.long) + ] + if padding_len > 0: + pad_mask_value = 0 if supervise_padding else 1 + mask_segments.append(torch.full((padding_len,), pad_mask_value, device=device, dtype=torch.long)) + prompt_mask = torch.cat(mask_segments, dim=0) + + labels = seq.clone() + mask_bool = prompt_mask.bool().unsqueeze(0) + labels[mask_bool] = self.ignore_id + + sequence_ids.append(seq) + prompt_masks.append(prompt_mask.unsqueeze(0)) + label_ids.append(labels) + + if len(sequence_ids) > 1: + max_seq_len = max(seq.shape[1] for seq in sequence_ids) + if any(seq.shape[1] != max_seq_len for seq in sequence_ids): + padded_sequences = [] + padded_masks = [] + padded_labels = [] + for seq, mask, labels in zip(sequence_ids, prompt_masks, label_ids): + seq_pad = max_seq_len - seq.shape[1] + if seq_pad > 0: + seq = torch.nn.functional.pad(seq, (0, seq_pad), value=right_pad_id) + mask = torch.nn.functional.pad(mask, (0, seq_pad), value=1) + pad_labels = torch.full((labels.shape[0], seq_pad), self.ignore_id, dtype=labels.dtype, device=device) + labels = torch.cat([labels, pad_labels], dim=1) + padded_sequences.append(seq) + padded_masks.append(mask) + padded_labels.append(labels) + sequence_ids = padded_sequences + prompt_masks = padded_masks + label_ids = padded_labels + + return ( + torch.cat(sequence_ids, dim=0), + torch.cat(prompt_masks, dim=0), + torch.cat(label_ids, dim=0), + ) + + def s2s_prompt( + self, + audio_usr_ids: list[torch.Tensor], + audio_asst_ids: list[torch.Tensor], + image_ids: Optional[list[Optional[torch.Tensor]]] = None, + ): + return self._s2s_prompt_impl(audio_usr_ids, audio_asst_ids, image_ids, supervise_padding=False) + + def s2s_prompt_eos( + self, + audio_usr_ids: list[torch.Tensor], + audio_asst_ids: list[torch.Tensor], + image_ids: Optional[list[Optional[torch.Tensor]]] = None, + ): + return self._s2s_prompt_impl(audio_usr_ids, audio_asst_ids, image_ids, supervise_padding=True) + + def s2s_prompt_ignore_padding(self, audio_usr_ids: list[torch.LongTensor], audio_asst_ids: list[torch.LongTensor]): + """ + Args: + audio_usr_ids: list[torch.LongTensor], each elem is of shape (1, S), S is seq_len + audio_asst_ids: list[torch.LongTensor], each elem is of shape (1, S), S is seq_len + Returns: + sequence_ids: torch.LongTensor, of shape (B, L) + prompt_masks: torch.LongTensor, of shape (B, L) + label_ids: torch.LongTensor, of shape (B, L) + """ + device = audio_usr_ids[0].device + sequence_ids, prompt_masks, label_ids = [], [], [] + + # Pad tokens + left_pad_id = self.pad_id + right_pad_id = self.text_tokenizer.eos_token_id + + # Task and special tokens + task_tensor = self.sptids_dict['<|s2s|>'].to(device).unsqueeze(0) + soa_tensor = self.sptids_dict['<|soa|>'].to(device).unsqueeze(0) + eoa_tensor = self.sptids_dict['<|eoa|>'].to(device).unsqueeze(0) + + # Headers for instruction tuning + u = "<|start_header_id|>user<|end_header_id|>\n" + a = "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n" + u_tokens = self.text_tokenizer(u, return_tensors="pt")['input_ids'] + a_tokens = self.text_tokenizer(a, return_tensors="pt")['input_ids'] + + # Maximum lengths + max_usr_len = self.max_audio_len # TODO: how to adjust + max_asst_len = self.max_audio_len # TODO: how to adjust + + for i in range(len(audio_usr_ids)): + + # User tokens (truncation and left padding) + current_usr_tokens = audio_usr_ids[i] + + effective_max_audio = max_usr_len - (task_tensor.shape[1] + u_tokens.shape[1] + soa_tensor.shape[1] + eoa_tensor.shape[1]) + if current_usr_tokens.shape[1] > effective_max_audio: + current_usr_tokens = current_usr_tokens[:, :effective_max_audio] + + usr_block = torch.cat([task_tensor, u_tokens, soa_tensor, current_usr_tokens, eoa_tensor], dim=1) + + num_padding = max_usr_len - usr_block.shape[1] + if num_padding > 0: + padding_tensor = torch.full((1, num_padding), left_pad_id, dtype=torch.long, device=device) + padded_usr_block = torch.cat([padding_tensor, usr_block], dim=1) + else: + padded_usr_block = usr_block + + # Assistant tokens (truncation and right padding) + asst_block = torch.cat([a_tokens, soa_tensor, audio_asst_ids[i], eoa_tensor], dim=1) + + if asst_block.shape[1] > max_asst_len: + asst_block = asst_block[:, :max_asst_len] + asst_block[0, -1] = eoa_tensor[0] + + num_padding = max_asst_len - asst_block.shape[1] + if num_padding > 0: + padding_tensor = torch.full((1, num_padding), right_pad_id, dtype=torch.long, device=device) + padded_asst_block = torch.cat([asst_block, padding_tensor], dim=1) + else: + padded_asst_block = asst_block + + # Full sequence + seq = torch.cat([ + padded_usr_block, + padded_asst_block, + ], dim=1) + + # Mask and labels + prefix_mask_len = max_usr_len + (a_tokens.shape[1] + soa_tensor.shape[1]) # padded usr block + asst headers + + actual_audio_target_len = asst_block.shape[1] - (a_tokens.shape[1] + soa_tensor.shape[1]) # exclude asst headers and but incldue audio tokens and + prompt_mask = torch.tensor( + [1] * prefix_mask_len + [0] * actual_audio_target_len + [1] * (seq.shape[1] - prefix_mask_len - actual_audio_target_len), + device=device, dtype=torch.long).unsqueeze(0) + + labels = seq.clone() + labels[prompt_mask.bool()] = self.ignore_id + + # Append to lists + sequence_ids.append(seq) + prompt_masks.append(prompt_mask) + label_ids.append(labels) + + return torch.cat(sequence_ids, dim=0), torch.cat(prompt_masks, dim=0), torch.cat(label_ids, dim=0) + + + def mmu_gen_prompt(self, image_ids, text_ids): + device = image_ids.device + sequence_ids = [] + prompt_masks = [] + max_text_len = self.max_text_len - 1 + for i in range(len(text_ids)): + + if len(text_ids[i]) == 0: + text_ids[i] = [self.text_tokenizer.bos_token_id] + elif text_ids[i][0] != self.text_tokenizer.bos_token_id: + text_ids[i] = [self.text_tokenizer.bos_token_id] + text_ids[i] + + temp_ids = text_ids[i] + [self.text_tokenizer.eos_token_id] + + if max_text_len >= len(temp_ids): + # minus 1 because task token was prepended to the former image tokens + temp_ids = temp_ids + [self.pad_id] * (max_text_len - len(temp_ids)) + else: + # should add the eos token + temp_ids = temp_ids[:max_text_len - 1] + [self.text_tokenizer.eos_token_id] + + # print(f"mmu temp_ids: {temp_ids}") + return_temp_ids = torch.cat([ + self.sptids_dict['<|mmu|>'].to(device), # task token + self.sptids_dict['<|soi|>'].to(device), + image_ids[i], + self.sptids_dict['<|eoi|>'].to(device), + torch.tensor(temp_ids).to(device), + ], dim=0) + + end_header_id = int(self.sptids_dict['<|end_header_id|>']) + end_header_pos = -1 + for pos in range(len(temp_ids) - 1, -1, -1): + if temp_ids[pos] == end_header_id: + end_header_pos = pos + break + if end_header_pos != -1: + prompt_length = len(return_temp_ids) - len(temp_ids) + end_header_pos + 1 + else: + prompt_length = len(return_temp_ids) - len(temp_ids) + predict_length = len(temp_ids) - prompt_length + print(f"prompt_length: {prompt_length}, predict_length: {predict_length}, all length: {len(return_temp_ids)}, {return_temp_ids[-predict_length:]}") + prompt_mask = [1] * prompt_length + [0] * predict_length + prompt_mask = torch.tensor(prompt_mask).to(device) + sequence_ids.append(return_temp_ids.unsqueeze(0)) + prompt_masks.append(prompt_mask.unsqueeze(0)) + return torch.cat(sequence_ids, dim=0), torch.cat(prompt_masks, dim=0) + + def r2i_prompt(self, image_ids, text_ids): + device = image_ids.device + sequence_ids = [] + prompt_masks = [] + label_ids = [] + r2i_id = int(self.sptids_dict['<|r2i|>']) + soi_id = int(self.sptids_dict['<|soi|>']) + eoi_id = int(self.sptids_dict['<|eoi|>']) + max_text_len = self.max_text_len - 1 # 512,include BOS text EOS + for i in range(len(text_ids)): + # note that, llama3 tokenizer automatically add the bot token at first but without eot + # for empty list [] + if len(text_ids[i]) == 0: + text_ids[i] = [self.text_tokenizer.bos_token_id] + elif text_ids[i][0]!= self.text_tokenizer.bos_token_id: + text_ids[i] = [self.text_tokenizer.bos_token_id] + text_ids[i] + text_ids_with_bos_eos = text_ids[i] + [self.text_tokenizer.eos_token_id] + if max_text_len >= len(text_ids_with_bos_eos): + # minus 1 because task token was prepended to the former image tokens + text_ids_full_len = text_ids_with_bos_eos + [self.text_tokenizer.eos_token_id] * (max_text_len - len(text_ids_with_bos_eos)) + else: + # should add the eos token + text_ids_full_len = text_ids_with_bos_eos[:max_text_len - 1] + [self.text_tokenizer.eos_token_id] + + sequence_ids.append(torch.cat([ + torch.tensor([r2i_id]).to(device), # task token + torch.tensor(text_ids_full_len).to(device), + torch.tensor([soi_id]).to(device), + image_ids[i], + torch.tensor([eoi_id]).to(device), + ], dim=0).unsqueeze(0)) + + end_header_id = int(self.sptids_dict['<|end_header_id|>']) + end_header_pos = -1 + for pos in range(len(text_ids_full_len) - 1, -1, -1): + if text_ids_full_len[pos] == end_header_id: + end_header_pos = pos + break + prompt_mask = torch.zeros(sequence_ids[i].size(1)).to(device) + prompt_mask[0] = 1 # task_id + if end_header_pos != -1: + prompt_mask[1:end_header_pos+2] = 1 + else: + prompt_mask[1:len(text_ids_full_len)+1] = 1 + prompt_mask[len(text_ids_full_len)+1] = 1 + prompt_mask[len(text_ids_full_len)+2+len(image_ids[i])] = 1 + prompt_masks.append(prompt_mask.unsqueeze(0)) + + return torch.cat(sequence_ids, dim=0), torch.cat(prompt_masks, dim=0), torch.cat(sequence_ids, dim=0) + + def mask_prompt(self): + pass + + def __call__(self, input, task, padding=True, config=None): + """ + input (tuple) : data pairs contain text(str), image(tensor), or videos(tensor). + task (str) : a flag indicates the current task. + """ + if task == "t2i": + text_ids = self.text_tokenizer( + input[0], + truncation=True, + max_length=self.max_text_len, + )['input_ids'] # (B, max_len) + image_ids = input[1] # (B, #tokens) + sequence_ids_with_masks = self.t2i_prompt(text_ids, image_ids, input[2]) + + elif task == "i2i": + text_ids = input[0] + original_image_ids = input[1] # (B, #tokens) + edited_image_ids = input[2] # (B, #tokens) + sequence_ids_with_masks = self.i2i_prompt(text_ids, original_image_ids, edited_image_ids, input[3]) + + elif task == "s2t": + image_ids = input[0] + text_ids = self.text_tokenizer( + input[1], + truncation=True, + max_length=self.max_text_len, + )['input_ids'] + sequence_ids_with_masks = self.s2t_prompt(image_ids, text_ids) + + elif task == "t2s": + audio_ids = input[1] + text_ids = self.text_tokenizer( + input[0], + truncation=True, + max_length=self.max_text_len, + )['input_ids'] + sequence_ids_with_masks = self.t2s_prompt(text_ids, audio_ids) + + elif task == "t2s_ip": + audio_ids = input[1] + text_ids = self.text_tokenizer( + input[0], + truncation=True, + max_length=self.max_text_len, + )['input_ids'] + sequence_ids_with_masks = self.t2s_prompt_ignore_padding(text_ids, audio_ids) + + elif task == "s2s_ip": + audio_user_ids = input[0] + audio_asst_ids = input[1] + image_ids = input[2] if len(input) > 2 else None + sequence_ids_with_masks = self.s2s_prompt(audio_user_ids, audio_asst_ids, image_ids) + + elif task == "s2s": + audio_user_ids = input[0] + audio_asst_ids = input[1] + image_ids = input[2] if len(input) > 2 else None + sequence_ids_with_masks = self.s2s_prompt_eos(audio_user_ids, audio_asst_ids, image_ids) + + # ------- WIP by yejoon ------- + elif task == "s2s_ip": + audio_user_ids = input[0] + audio_asst_ids = input[1] + sequence_ids_with_masks = self.s2s_prompt_ignore_padding(audio_user_ids, audio_asst_ids) + + elif task == "t2v": + text_ids = self.text_tokenizer( + input[0], + truncation=True, + max_length=self.max_text_len, + )['input_ids'] # (B, max_len) + image_ids = input[1] # (B, #tokens) + sequence_ids_with_masks = self.t2v_prompt(text_ids, image_ids, input[2]) + + elif task == "t2i_plus_lm": + text_ids = self.text_tokenizer( + input[0], + truncation=True, + max_length=self.max_text_len, + )['input_ids'] # (B, max_len) + image_ids = input[1] # (B, #tokens) + sequence_ids_with_masks = self.t2i_prompt(text_ids[:config.training.batch_size], image_ids, + input[2]) + sequence_ids_with_masks_lm = self.lm_prompt(text_ids[config.training.batch_size:], input[3]) + return sequence_ids_with_masks, sequence_ids_with_masks_lm + + elif task == "t2i_gen": + text_ids = self.text_tokenizer( + input[0], + truncation=True, + max_length=self.max_text_len, + )['input_ids'] # (B, max_len) + image_ids = input[1] # (B, #tokens) + sequence_ids_with_masks = self.t2i_gen_prompt(text_ids, image_ids) + + elif task == "t2v_gen": + text_ids = self.text_tokenizer( + input[0], + truncation=True, + max_length=self.max_text_len, + )['input_ids'] # (B, max_len) + image_ids = input[1] # (B, #tokens) + sequence_ids_with_masks = self.t2v_gen_prompt(text_ids, image_ids) + + elif task == "lm": + text_ids = self.text_tokenizer(input[0], truncation=True)['input_ids'] # (B, max_len) + sequence_ids_with_masks = self.lm_prompt(text_ids, input[1]) + + elif task == "lm_chat": + text_ids = self.text_tokenizer(input[0], truncation=True)['input_ids'] # (B, max_len) + sequence_ids_with_masks = self.lm_chat_prompt(text_ids, input[1]) + + # elif task == "mmu": + # image_ids = input[0] + # text_ids = self.text_tokenizer(input[1])['input_ids'] + # sequence_ids_with_masks = self.mmu_prompt(image_ids, text_ids) + + elif task == "mmu": + text_ids = self.text_tokenizer(input[1], truncation=True)['input_ids'] # (B, max_len) + + sequence_ids_with_masks = self.mmu_mult_prompt( + batch_image_ids_list=input[0], + batch_text_ids=text_ids, + ) + + elif task == "v2t": + video_ids = input[0] + text_ids = self.text_tokenizer( + input[1], + truncation=True, + max_length=self.max_text_len, + )['input_ids'] + sequence_ids_with_masks = self.v2t_prompt(video_ids, text_ids) + + elif task == 'v2s': + video_ids = input[0] + text_ids = self.text_tokenizer( + input[1], + truncation=True, + max_length=self.max_text_len, + )['input_ids'] + audio_ids = input[2] + sequence_ids_with_masks = self.v2s_prompt(video_ids, text_ids, audio_ids) + + elif task == 'v2s_ip': + video_ids = input[0] + text_ids = self.text_tokenizer( + input[1], + truncation=True, + max_length=self.max_text_len, + )['input_ids'] + audio_ids = input[2] + sequence_ids_with_masks = self.v2s_prompt_ignore_padding(video_ids, text_ids, audio_ids) + + elif task == "r2i": + image_ids = input[0] + text_ids = self.text_tokenizer(input[1])['input_ids'] + sequence_ids_with_masks = self.r2i_prompt(image_ids, text_ids) + + elif task == "i2i_gen": + text_ids = input[0] + input_image_ids = input[1] + output_image_ids = input[2] + sequence_ids_with_masks = self.i2i_gen_prompt(text_ids, input_image_ids, output_image_ids) + + elif task == "t2s_gen": + text_ids = self.text_tokenizer(input[0])['input_ids'] # (B, max_len) + audio_ids = input[1] # (B, #tokens) + sequence_ids_with_masks = self.t2s_gen_prompt(text_ids, audio_ids) + + elif task == "t2s_fixed_gen": + text_ids = self.text_tokenizer(input[0])['input_ids'] # (B, max_len) + audio_ids = input[1] # (B, #tokens) + sequence_ids_with_masks = self.t2s_fixed_gen_prompt(text_ids, audio_ids) + + elif task == "s2s_gen": + audio_user_ids = input[0] + audio_placeholders = input[1] + image_ids = input[2] if len(input) > 2 else None + sequence_ids_with_masks = self.s2s_gen_prompt(audio_user_ids, audio_placeholders, image_ids) + + elif task == "v2s_gen": + video_ids = input[0] + text_ids = self.text_tokenizer( + input[1], + truncation=True, + max_length=self.max_text_len, + )['input_ids'] + audio_ids = input[2] + sequence_ids_with_masks = self.v2s_gen_prompt(video_ids, text_ids, audio_ids) + + else: + raise NotImplementedError + + return sequence_ids_with_masks + + +if __name__ == '__main__': + pass diff --git a/MMaDA/training/train_mmada.py b/MMaDA/training/train_mmada.py new file mode 100644 index 0000000000000000000000000000000000000000..f039048dc4e4b0c1e2be2bf82fd94a64691e4aa9 --- /dev/null +++ b/MMaDA/training/train_mmada.py @@ -0,0 +1,983 @@ +# Copyright 2025 MMaDA Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +os.environ["TOKENIZERS_PARALLELISM"] = "true" +import json +import logging +import math +import shutil +import time +from pathlib import Path +from typing import Union + +import numpy as np +from PIL import Image +from omegaconf import OmegaConf +import wandb +import torch +from torch.optim import AdamW +from lightning.pytorch.utilities import CombinedLoader + +from transformers import AutoTokenizer, AutoConfig +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import DistributedType, set_seed + +from training.data import Text2ImageDataset +from training.utils import get_config, flatten_omega_conf, image_transform +from training.imagenet_dataset import ImageNetDataset +from parquet import RefinedWebDataset + +from models import MAGVITv2, get_mask_schedule, MMadaModelLM, MMadaConfig +from training.prompting_utils import UniversalPrompting +from models.lr_schedulers import get_scheduler +from models.logging import set_verbosity_info, set_verbosity_error + +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler + + +SYSTEM_PROMPT_LEN = 28 + +from training.utils import get_config, flatten_omega_conf, mask_or_random_replace_tokens, AverageMeter + +try: + import apex + + is_apex_available = True +except ImportError: + is_apex_available = False + +logger = get_logger(__name__, log_level="INFO") + + +def get_vq_model_class(model_type): + if model_type == "magvitv2": + return MAGVITv2 + elif model_type == "vq16": + return VQ_16 + else: + raise ValueError(f"model_type {model_type} not supported.") + + +def main(): + ######################### + # SETUP Accelerator # + ######################### + config = get_config() + + # Enable TF32 on Ampere GPUs + if config.training.enable_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.benchmark = True + torch.backends.cudnn.deterministic = False + + config.experiment.logging_dir = str(Path(config.experiment.output_dir) / "logs") + accelerator = Accelerator( + gradient_accumulation_steps=config.training.gradient_accumulation_steps, + mixed_precision=config.training.mixed_precision, + log_with="wandb", + project_dir=config.experiment.logging_dir, + split_batches=True, + ) + + total_batch_size_per_gpu = (config.training.batch_size_t2i + + config.training.batch_size_lm + + config.training.batch_size_mmu) + total_batch_size = ( + (config.training.batch_size_t2i + config.training.batch_size_lm + config.training.batch_size_mmu) + * accelerator.num_processes * config.training.gradient_accumulation_steps + ) + + if accelerator.distributed_type == DistributedType.DEEPSPEED: + accelerator.state.deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = ( + total_batch_size_per_gpu + ) + + ##################################### + # SETUP LOGGING, SEED and CONFIG # + ##################################### + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + set_verbosity_info() + else: + set_verbosity_error() + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + resume_wandb_run = config.wandb.resume + run_id = config.wandb.get("run_id", None) + if run_id is None: + resume_wandb_run = False + run_id = wandb.util.generate_id() + config.wandb.run_id = run_id + + wandb_init_kwargs = dict( + name=config.experiment.name, + id=run_id, + resume=resume_wandb_run, + entity=config.wandb.get("entity", None), + config_exclude_keys=[], + ) + wandb_config = {k: v for k, v in flatten_omega_conf(config, resolve=True)} + wandb_config.pop("experiment.resume_from_checkpoint") + + accelerator.init_trackers( + config.experiment.project, + config=wandb_config, + init_kwargs={"wandb": wandb_init_kwargs}, + ) + + if accelerator.is_main_process: + os.makedirs(config.experiment.output_dir, exist_ok=True) + config_path = Path(config.experiment.output_dir) / "config.yaml" + logging.info(f"Saving config to {config_path}") + OmegaConf.save(config, config_path) + + # If passed along, set the training seed now. + if config.training.seed is not None: + set_seed(config.training.seed) + + ######################### + # MODELS and OPTIMIZER # + ######################### + logger.info("Loading models and optimizer") + + tokenizer = AutoTokenizer.from_pretrained(config.model.mmada.pretrained_model_path, padding_side="left") + + uni_prompting = UniversalPrompting(tokenizer, max_text_len=config.dataset.preprocessing.max_seq_length, + special_tokens=( + "<|soi|>", "<|eoi|>", "<|sov|>", "<|eov|>", "<|t2i|>", + "<|mmu|>", "<|t2v|>", "<|v2v|>", "<|lvg|>" + ), + ignore_id=-100, cond_dropout_prob=config.training.cond_dropout_prob, use_reserved_token=True) + + print('special tokens : \n', uni_prompting.sptids_dict) + + # VQ model for processing image into discrete tokens + vq_model = get_vq_model_class(config.model.vq_model.type) + if config.model.vq_model.get("pretrained_model_path", None): + vq_model = vq_model().to(accelerator.device) + state_dict = torch.load(config.model.vq_model.pretrained_model_path)['model'] + vq_model.load_state_dict(state_dict) + else: + vq_model = vq_model.from_pretrained(config.model.vq_model.vq_model_name).to(accelerator.device) + vq_model.eval() + vq_model.requires_grad_(False) + + # Initialize mmada in pretraining stage + base_config = AutoConfig.from_pretrained(config.model.mmada.pretrained_model_path).to_dict() + mmada_config_dict = {k: v for k, v in config.model.mmada.items()} + merged_config = {**base_config, **mmada_config_dict} + mmada_config = MMadaConfig(**merged_config) + model = MMadaModelLM.from_pretrained(config.model.mmada.pretrained_model_path, torch_dtype=torch.bfloat16, config=mmada_config) + model.resize_token_embeddings(mmada_config.new_vocab_size) + model.config.embedding_size = model.config.vocab_size + model = model.to(accelerator.device) + + mask_id = model.config.mask_token_id + + ################################## + # Optimizer and LR scheduler # + ################################# + optimizer_config = config.optimizer.params + + # no decay on bias and layernorm and embedding + no_decay = ["bias", "layer_norm.weight", "mlm_ln.weight", "embeddings.weight"] + optimizer_grouped_parameters = [ + { + "params": [p for n, p in model.named_parameters() if + p.requires_grad and not any(nd in n for nd in no_decay)], + "weight_decay": optimizer_config.weight_decay, + }, + { + "params": [p for n, p in model.named_parameters() if + p.requires_grad and any(nd in n for nd in no_decay)], + "weight_decay": 0.0, + }, + ] + + optimizer_type = config.optimizer.name + if optimizer_type == "adamw": + optimizer = AdamW( + optimizer_grouped_parameters, + lr=optimizer_config.learning_rate, + betas=(optimizer_config.beta1, optimizer_config.beta2), + weight_decay=optimizer_config.weight_decay, + eps=optimizer_config.epsilon, + ) + else: + raise ValueError(f"Optimizer {optimizer_type} not supported") + + # Create mask scheduler + if config.get("mask_schedule", None) is not None: + schedule = config.mask_schedule.schedule + args = config.mask_schedule.get("params", {}) + mask_schedule = get_mask_schedule(schedule, **args) + else: + mask_schedule = get_mask_schedule(config.training.get("mask_schedule", "cosine")) + + lr_scheduler = get_scheduler( + config.lr_scheduler.scheduler, + optimizer=optimizer, + num_training_steps=config.training.max_train_steps, + num_warmup_steps=config.lr_scheduler.params.warmup_steps, + min_lr_scale=config.lr_scheduler.params.min_lr_scale + ) + + ################################## + # DATALOADER # + ################################# + logger.info("Creating dataloaders and lr_scheduler") + + total_batch_size_t2i_without_accum = config.training.batch_size_t2i * accelerator.num_processes + total_batch_size_t2i = ( + config.training.batch_size_t2i * accelerator.num_processes * config.training.gradient_accumulation_steps + ) + + # DataLoaders creation: + # We use webdataset for data loading. The dataloaders are created with sampling with replacement. + # We don't do dataset resuming here, instead we resample the shards and buffer each time. The sampling is stochastic. + # This means that the dataloading is not deterministic, but it's fast and efficient. + preproc_config = config.dataset.preprocessing + dataset_config = config.dataset.params + + # Data for generation + if config.dataset.gen_type == "t2i": + dataset = Text2ImageDataset( + train_shards_path_or_url=dataset_config.train_t2i_shards_path_or_url, + tokenizer=None, # we want to get raw texts + max_seq_length=preproc_config.max_seq_length, + num_train_examples=config.experiment.max_train_examples_t2i, + per_gpu_batch_size=config.training.batch_size_t2i, + global_batch_size=total_batch_size_t2i_without_accum, + num_workers=dataset_config.num_workers, + resolution=preproc_config.resolution, + shuffle_buffer_size=dataset_config.shuffle_buffer_size, + pin_memory=dataset_config.pin_memory, + persistent_workers=dataset_config.persistent_workers, + external_caption_path=dataset_config.external_caption_path, + external_journeydb_caption_path=dataset_config.external_journeydb_caption_path, + external_laion12m_caption_path=dataset_config.external_laion12m_caption_path, + external_cc12m_caption_path=dataset_config.external_cc12m_caption_path, + ) + train_dataloader_t2i = dataset.train_dataloader + num_update_steps_per_epoch = math.ceil( + train_dataloader_t2i.num_batches / config.training.gradient_accumulation_steps) + num_train_epochs = math.ceil(config.training.max_train_steps / num_update_steps_per_epoch) + + elif config.dataset.gen_type == "t2i_parquet": + # this part relies on the internal packages, which will not be released + num_update_steps_per_epoch = math.ceil(config.experiment.max_train_examples_t2i / total_batch_size_t2i) + num_train_epochs = math.ceil(config.training.max_train_steps / num_update_steps_per_epoch) + + train_dataloader_t2i = create_imagetext_dataloader( + train_shards_path_or_url=dataset_config.train_t2i_shards_path_or_url, + batch_size=config.training.batch_size_t2i, + image_size=preproc_config.resolution, + num_workers=dataset_config.num_workers, + num_readers=32, + predefined_steps=num_update_steps_per_epoch, + drop_last=True, + shuffle=True, + shuffle_buffer_size=dataset_config.shuffle_buffer_size + ) + + elif config.dataset.gen_type == "imagenet1k": + dataset_imagenet = ImageNetDataset( + dataset_config.train_t2i_shards_path_or_url, + image_size=preproc_config.resolution, + ) + + print('process index : ', + accelerator.process_index, ', ', accelerator.num_processes, + "Length: ", len(dataset_imagenet)) + + if accelerator.num_processes > 1: + sampler = DistributedSampler(dataset_imagenet, + num_replicas=accelerator.num_processes, + rank=accelerator.process_index, + shuffle=True, + ) + shuffle = False + else: + sampler = None + shuffle = True + + train_dataloader_t2i = DataLoader(dataset_imagenet, batch_size=config.training.batch_size_t2i, + sampler=sampler, collate_fn=dataset_imagenet.collate_fn, + shuffle=shuffle, num_workers=dataset_config.num_workers) + num_update_steps_per_epoch = math.ceil(len(dataset_imagenet) / total_batch_size_t2i) + num_train_epochs = math.ceil(config.training.max_train_steps / num_update_steps_per_epoch) + + else: + raise ValueError(f"Unsupported dataset type {config.dataset.type}") + + total_batch_size_mmu_without_accum = config.training.batch_size_mmu * accelerator.num_processes + # Data for image captioning + if config.dataset.und_type == "captioning": + dataset_mmu = Text2ImageDataset( + train_shards_path_or_url=dataset_config.train_mmu_shards_path_or_url, + tokenizer=None, # we want to get raw texts + max_seq_length=preproc_config.max_seq_length, + num_train_examples=config.experiment.max_train_examples_mmu, + per_gpu_batch_size=config.training.batch_size_mmu, + global_batch_size=total_batch_size_mmu_without_accum, + num_workers=dataset_config.num_workers, + resolution=preproc_config.resolution, + shuffle_buffer_size=dataset_config.shuffle_buffer_size, + pin_memory=dataset_config.pin_memory, + persistent_workers=dataset_config.persistent_workers, + external_caption_path=dataset_config.external_caption_path, + external_journeydb_caption_path=dataset_config.external_journeydb_caption_path, + external_laion12m_caption_path=dataset_config.external_laion12m_caption_path, + external_cc12m_caption_path=dataset_config.external_cc12m_caption_path, + is_captioning=True, + add_caption_prompt=dataset_config.add_caption_prompt, + ) + train_dataloader_mmu = dataset_mmu.train_dataloader + + elif config.dataset.und_type == "captioning_parquet": + train_dataloader_mmu = create_imagetext_dataloader( + train_shards_path_or_url=dataset_config.train_mmu_shards_path_or_url, + batch_size=config.training.batch_size_mmu, + image_size=preproc_config.resolution, + num_workers=dataset_config.num_workers, + num_readers=32, + predefined_steps=num_update_steps_per_epoch, + drop_last=True, + shuffle=True, + shuffle_buffer_size=dataset_config.shuffle_buffer_size, + is_captioning=True + ) + + else: + raise NotImplementedError(f"Unsupported dataset type {config.dataset.und_type}") + + # LLM pure text dataset: RefinedWeb + dataset_lm = RefinedWebDataset(data_path=dataset_config.train_lm_shards_path_or_url, + rank=accelerator.process_index, + world_size=accelerator.num_processes, + num_workers=dataset_config.num_workers) + + train_dataloader_lm = torch.utils.data.DataLoader(dataset_lm, batch_size=config.training.batch_size_lm, + sampler=None, collate_fn=dataset_lm.collate_fn, + num_workers=dataset_config.num_workers) + + # Combine these dataloaders into a single iterable model + iterables = { + "t2i_flow": train_dataloader_t2i, + "lm_flow": train_dataloader_lm, + "mmu_flow": train_dataloader_mmu, + } + + combined_dataloader = CombinedLoader(iterables, mode=config.dataset.combined_loader_mode) + + ################################## + # MODEL RESUME # + ################################# + global_step = 0 + first_epoch = 0 + + if config.experiment.resume_from_checkpoint: + dirs = os.listdir(config.experiment.output_dir) + logger.info(f"dirs: {dirs}") + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + logger.info(f"path: {path}") + if path is not None: + path = os.path.join(config.experiment.output_dir, path) + logger.info(f"Resuming from checkpoint: {path}") + global_step = int(os.path.basename(path).split("-")[1]) + first_epoch = global_step // num_update_steps_per_epoch + if os.path.exists(f'{path}/unwrapped_model/pytorch_model.bin'): + state_dict = torch.load(f'{path}/unwrapped_model/pytorch_model.bin', map_location="cpu") + model.load_state_dict(state_dict, strict=True) + del state_dict + elif os.path.exists(f'{path}/unwrapped_model/pytorch_model.bin.index.json'): + from safetensors.torch import load_file + from transformers.modeling_utils import load_sharded_checkpoint + load_sharded_checkpoint(model, f'{path}/unwrapped_model/') + # if safetensors sharded checkpoint exists + elif os.path.exists(f'{path}/unwrapped_model/model.safetensors.index.json'): + from transformers.modeling_utils import load_sharded_checkpoint + load_sharded_checkpoint( + model, + f'{path}/unwrapped_model/', + # weight_map=None, + # load_state_dict_fn="safetensors" + ) + else: + raise FileNotFoundError(f"Checkpoint {path}/unwrapped_model/pytorch_model.bin not found") + else: + logger.info("Not resuming from checkpoint") + + ################################## + # Prepare accelerator # + ################################# + logger.info("Preparing model, optimizer and dataloaders") + model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler) + + vq_model.to(device=accelerator.device) + + mask_dtype = model.get_input_embeddings().weight.dtype + + ################################## + # Training # + ################################# + logger.info("***** Running training *****") + logger.info(f" Num training steps = {config.training.max_train_steps}") + logger.info(f" Instantaneous batch size per device = {total_batch_size_per_gpu}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {config.training.gradient_accumulation_steps}") + + @torch.no_grad() + def prepare_inputs_and_labels( + pixel_values_or_image_ids: Union[torch.FloatTensor, torch.LongTensor], + texts: Union[str, str], + min_masking_rate: float = 0.0, + is_train: bool = True, + ): + + image_tokens = vq_model.get_code(pixel_values_or_image_ids) + image_tokens = image_tokens + len(uni_prompting.text_tokenizer) + # create MLM mask and labels + input_ids, labels, loss_weight, mask_prob = mask_or_random_replace_tokens( + image_tokens, + mask_id, + config, + mask_schedule=mask_schedule, + is_train=is_train, + ) + input_ids, masks, labels = uni_prompting((texts, input_ids, labels), 't2i') + return input_ids, labels, mask_prob, image_tokens, masks + + @torch.no_grad() + def prepare_inputs_and_labels_for_text( + texts: Union[str, str], max_seq_len, eps=1e-3 + ): + # create MLM mask and labels + + input_ids_lm, prompt_mask, labels_lm = uni_prompting((texts_lm, max_seq_len), 'lm') + b, l = input_ids_lm.shape + t = torch.rand(b, device=input_ids_lm.device) + p_mask = (1 - eps) * t + eps + p_mask = p_mask[:, None].repeat(1, l) + + masked_indices = torch.rand((b, l), device=input_ids_lm.device) < p_mask + # 126336 is used for [MASK] token + noisy_batch = torch.where(masked_indices, mask_id, input_ids_lm) + masked_indices = noisy_batch == mask_id + + return noisy_batch, labels_lm, p_mask + + @torch.no_grad() + def prepare_inputs_and_labels_for_mmu( + input_ids_mmu, prompt_masks, labels_mmu, eps=1e-3 + ): + b, l = input_ids_mmu.shape + t = torch.rand(b, device=input_ids_mmu.device) + p_mask = (1 - eps) * t + eps + p_mask = p_mask[:, None].repeat(1, l) + + masked_indices = torch.rand((b, l), device=input_ids_mmu.device) < p_mask + # 126336 is used for [MASK] token + noisy_batch = torch.where(masked_indices, mask_id, input_ids_mmu) + masked_indices = noisy_batch == mask_id + noisy_batch[prompt_masks.bool()] = input_ids_mmu[prompt_masks.bool()] + masked_indices = noisy_batch == mask_id + + prompt_masks = prompt_masks.to(torch.int64) + answer_lengths = torch.sum((1 - prompt_masks), dim=-1, keepdim=True) + answer_lengths = answer_lengths.repeat(1, noisy_batch.shape[1]) + + return noisy_batch, labels_mmu, p_mask, answer_lengths + + + + batch_time_m = AverageMeter() + data_time_m = AverageMeter() + end = time.time() + + for epoch in range(first_epoch, num_train_epochs): + model.train() + for batch, batch_idx, dataloader_idx in combined_dataloader: + # for loss calculation + batch_size_t2i = batch["t2i_flow"]["images"].shape[0] + batch_size_lm = len(batch["lm_flow"]["input_ids"]) + batch_size_mmu = batch["mmu_flow"]["images"].shape[0] + + # *-------*-------*-------*-------*-------*-------*-------*-------*-------*-------*-------* + # Build formatted sequences for class-conditional/text-to-image generation + # *-------*-------*-------*-------*-------*-------*-------*-------*-------*-------*-------* + pixel_values, texts = batch["t2i_flow"]["images"], batch["t2i_flow"]["input_ids"] + pixel_values = pixel_values.to(accelerator.device, non_blocking=True) + data_time_m.update(time.time() - end) + + # Encode images to image tokens, mask them and create input and labels + ( + input_ids, + labels, + mask_prob, + image_tokens_ori, + t2i_masks + ) = prepare_inputs_and_labels(pixel_values, texts, config.training.min_masking_rate) + + # *-------*-------*-------*-------*-------*-------*-------*-------*-------*-------*-------* + # Build formatted sequences for language modeling + # *-------*-------*-------*-------*-------*-------*-------*-------*-------*-------*-------* + max_seq_len = input_ids.shape[-1] + texts_lm = batch["lm_flow"]["input_ids"] + ( + input_ids_lm, + labels_lm, + p_mask_lm + ) = prepare_inputs_and_labels_for_text(texts_lm, max_seq_len) + input_ids = torch.cat((input_ids, input_ids_lm.to(input_ids.device)), dim=0) + labels = torch.cat((labels, labels_lm.to(input_ids.device)), dim=0) + + + + # *-------*-------*-------*-------*-------*-------*-------*-------*-------*-------*-------* + # Build formatted sequences for captioning/multimodal understanding + # *-------*-------*-------*-------*-------*-------*-------*-------*-------*-------*-------* + if "llava" in config.dataset.und_type: + pixel_values_mmu, input_ids_mmu, labels_mmu = (batch["mmu_flow"]["images"], batch["mmu_flow"]["input_ids"],batch["mmu_flow"]["labels"]) + pixel_values_mmu = pixel_values_mmu.to(accelerator.device, non_blocking=True) + input_ids_mmu = input_ids_mmu.to(accelerator.device, non_blocking=True) + image_tokens_mmu = vq_model.get_code(pixel_values_mmu) + image_tokens_mmu = image_tokens_mmu + len(uni_prompting.text_tokenizer) + + input_ids_mmu = torch.cat([ + (torch.ones(input_ids_mmu.shape[0], 1) * uni_prompting.sptids_dict['<|mmu|>']).to( + accelerator.device), + (torch.ones(input_ids_mmu.shape[0], 1) * uni_prompting.sptids_dict['<|soi|>']).to( + accelerator.device), + image_tokens_mmu, + (torch.ones(input_ids_mmu.shape[0], 1) * uni_prompting.sptids_dict['<|eoi|>']).to( + accelerator.device), + input_ids_mmu, + ], dim=1).long() + + labels_mmu = torch.cat([ + (torch.ones(input_ids_mmu.shape[0], 1) * uni_prompting.ignore_id).to(accelerator.device), + (torch.ones(input_ids_mmu.shape[0], 1) * uni_prompting.ignore_id).to(accelerator.device), + torch.ones_like(image_tokens_mmu) * uni_prompting.ignore_id, + (torch.ones(input_ids_mmu.shape[0], 1) * uni_prompting.ignore_id).to(accelerator.device), + labels_mmu.to(accelerator.device) + ], dim=1).long() + + else: + pixel_values_mmu, texts_mmu = batch["mmu_flow"]["images"], batch["mmu_flow"]["input_ids"] + pixel_values_mmu = pixel_values_mmu.to(accelerator.device, non_blocking=True) + image_tokens_mmu = vq_model.get_code(pixel_values_mmu) + image_tokens_mmu = image_tokens_mmu + len(uni_prompting.text_tokenizer) + + input_ids_mmu, prompt_masks, labels_mmu = uni_prompting((image_tokens_mmu, texts_mmu), 'mmu') + ( + input_ids_mmu, + labels_mmu, + p_mask_mmu, + answer_lengths + ) = prepare_inputs_and_labels_for_mmu(input_ids_mmu, prompt_masks, labels_mmu) + input_ids_mmu = input_ids_mmu.to(accelerator.device, non_blocking=True) + + input_ids = torch.cat((input_ids, input_ids_mmu.to(input_ids.device)), dim=0) + labels = torch.cat((labels, labels_mmu.to(input_ids.device)), dim=0) + + if global_step == 0 and epoch == 0: + logger.info("Input ids: {}".format(input_ids)) + logger.info("Labels: {}".format(labels)) + + with accelerator.accumulate(model): + logits, loss_t2i, loss_lm, loss_mmu = model.forward_process( + input_ids=input_ids, + labels=labels, + batch_size_t2i=batch_size_t2i, + batch_size_lm=batch_size_lm, + batch_size_mmu=batch_size_mmu, + max_seq_length=config.dataset.preprocessing.max_seq_length, + p_mask_lm=p_mask_lm, + p_mask_mmu=p_mask_mmu, + answer_lengths=answer_lengths, + t2i_masks=t2i_masks + ) + # Gather the losses across all processes for logging (if we use distributed training). + avg_loss_t2i = accelerator.gather(loss_t2i.repeat(config.training.batch_size_t2i)).mean() + avg_loss_lm = accelerator.gather(loss_lm.repeat(config.training.batch_size_lm)).mean() + avg_loss_mmu = accelerator.gather(loss_mmu.repeat(config.training.batch_size_mmu)).mean() + loss = config.training.t2i_coeff * loss_t2i + \ + config.training.lm_coeff * loss_lm + \ + config.training.mmu_coeff * loss_mmu + + avg_masking_rate = accelerator.gather(mask_prob.repeat(config.training.batch_size_t2i)).mean() + + accelerator.backward(loss) + + if config.training.max_grad_norm is not None and accelerator.sync_gradients: + accelerator.clip_grad_norm_(model.parameters(), config.training.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + + # log gradient norm before zeroing it + if ( + accelerator.sync_gradients + and (global_step + 1) % config.experiment.log_grad_norm_every == 0 + and accelerator.is_main_process + ): + log_grad_norm(model, accelerator, global_step + 1) + + optimizer.zero_grad(set_to_none=True) + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + + batch_time_m.update(time.time() - end) + end = time.time() + + # Log metrics + if (global_step + 1) % config.experiment.log_every == 0: + samples_per_second_per_gpu = ( + config.training.gradient_accumulation_steps * total_batch_size_per_gpu / batch_time_m.val + ) + logs = { + "step_loss_t2i": avg_loss_t2i.item(), + "step_loss_mmu": avg_loss_mmu.item(), + "step_loss_lm": avg_loss_lm.item(), + "lr": lr_scheduler.get_last_lr()[0], + "avg_masking_rate": avg_masking_rate.item(), + "samples/sec/gpu": samples_per_second_per_gpu, + "data_time": data_time_m.val, + "batch_time": batch_time_m.val, + } + accelerator.log(logs, step=global_step + 1) + + logger.info( + f"Step: {global_step + 1} " + f"Loss_t2i: {avg_loss_t2i.item():0.4f} " + f"Loss_mmu: {avg_loss_mmu.item():0.4f} " + f"Loss_lm: {avg_loss_lm.item():0.4f} " + f"Data (t): {data_time_m.val:0.4f}, {samples_per_second_per_gpu:0.2f}/s/gpu " + f"Batch (t): {batch_time_m.val:0.4f} " + f"LR: {lr_scheduler.get_last_lr()[0]:0.6f}" + ) + + # resetting batch / data time meters per log window + batch_time_m.reset() + data_time_m.reset() + + # Save model checkpoint + if (global_step + 1) % config.experiment.save_every == 0: + save_checkpoint(model, config, accelerator, global_step + 1) + + if ((global_step + 1) % config.experiment.generate_every == 0 or global_step == 0) and accelerator.is_main_process: + generate_images( + model, + vq_model, + uni_prompting, + accelerator, + config, + global_step + 1, + mask_schedule=mask_schedule, + ) + + visualize_predictions( + model, + vq_model, + uni_prompting, + config, + global_step + 1, + input_ids, + image_tokens_ori, + batch["t2i_flow"]["images"], + texts, + logits, + accelerator + ) + + understanding_images( + model, + vq_model, + uni_prompting, + accelerator, + config, + global_step + 1, + ) + + global_step += 1 + + if global_step >= config.training.max_train_steps: + break + + accelerator.wait_for_everyone() + + # Evaluate and save checkpoint at the end of training + save_checkpoint(model, config, accelerator, global_step) + + # Save the final trained checkpoint + if accelerator.is_main_process: + model = accelerator.unwrap_model(model) + model.save_pretrained(config.experiment.output_dir, safe_serialization=True) + + accelerator.end_training() + + +@torch.no_grad() +def visualize_predictions( + model, + vq_model, + uni_prompting, + config, + global_step, + input_ids, + image_tokens_ori, + ori_images, + texts, + logits, + accelerator +): + logger.info("Visualizing predictions...") + model.eval() + + recons_images = vq_model.decode_code(image_tokens_ori - len(uni_prompting.text_tokenizer)) + recons_images = torch.clamp((recons_images + 1.0) / 2.0, min=0.0, max=1.0) + recons_images *= 255.0 + recons_images = recons_images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8) + + images = torch.clamp((ori_images + 1.0) / 2.0, min=0.0, max=1.0) + images *= 255.0 + images = images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8) + predictions = logits[:config.training.batch_size_t2i, -(config.model.mmada.num_vq_tokens + 1):-1:, len(uni_prompting.text_tokenizer) + config.model.mmada.num_new_special_tokens: len(uni_prompting.text_tokenizer) + config.model.mmada.num_new_special_tokens + config.model.mmada.codebook_size] + + predictions = predictions.argmax(axis=-1) + mask_token_id = accelerator.unwrap_model(model).config.mask_token_id - len(uni_prompting.text_tokenizer) + input_ids = input_ids[:config.training.batch_size_t2i, -(config.model.mmada.num_vq_tokens + 1):-1:] - len(uni_prompting.text_tokenizer) + mask_ratio = list((torch.where(input_ids == mask_token_id, 1, 0).sum( + dim=-1) / config.model.mmada.num_vq_tokens).cpu().numpy()) + predicted_images = torch.where(input_ids == mask_token_id, predictions, input_ids) + predicted_images = vq_model.decode_code(predicted_images) + predicted_images = torch.clamp((predicted_images + 1.0) / 2.0, min=0.0, max=1.0) + predicted_images *= 255.0 + predicted_images = predicted_images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8) + predicted_images = np.concatenate((images, recons_images, predicted_images), 2) + pil_images = [Image.fromarray(image) for image in predicted_images] + + # Log images + wandb_images = [wandb.Image(image, caption=f'mask ratio: {r:0.2f} \n caption: {texts[i]}') for i, (image, r) in + enumerate(zip(pil_images, mask_ratio))] + wandb.log({"Original images v.s. Reconstructed images v.s. Predicted images": wandb_images}, step=global_step) + + model.train() + + +@torch.no_grad() +def generate_images( + model, + vq_model, + uni_prompting, + accelerator, + config, + global_step, + mask_schedule, +): + logger.info("Generating images...") + model.eval() + + # read validation prompts from file + with open(config.dataset.params.validation_prompts_file, "r") as f: + validation_prompts = f.read().splitlines() + + + mask_dtype = model.get_input_embeddings().weight.dtype + mask_token_id = accelerator.unwrap_model(model).config.mask_token_id + image_tokens = torch.ones((len(validation_prompts), config.model.mmada.num_vq_tokens), dtype=torch.long, + device=accelerator.device) * mask_token_id + input_ids, attention_mask = uni_prompting((validation_prompts, image_tokens), 't2i_gen') + if config.training.guidance_scale > 0: + uncond_input_ids, uncond_attention_mask = uni_prompting(([''] * len(validation_prompts), image_tokens), 't2i_gen') + else: + uncond_input_ids = None + uncond_attention_mask = None + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + else: + weight_dtype = torch.float32 + + with torch.autocast("cuda", dtype=weight_dtype, enabled=accelerator.mixed_precision != "no"): + # Generate images + gen_token_ids = accelerator.unwrap_model(model).t2i_generate( + input_ids=input_ids, + uncond_input_ids=uncond_input_ids, + attention_mask=attention_mask, + uncond_attention_mask=uncond_attention_mask, + guidance_scale=config.training.guidance_scale, + temperature=config.training.get("generation_temperature", 1.0), + timesteps=config.training.generation_timesteps, + noise_schedule=mask_schedule, + noise_type=config.training.get("noise_type", "mask"), + predict_all_tokens=config.training.get("predict_all_tokens", False), + seq_len=config.model.mmada.num_vq_tokens, + uni_prompting=uni_prompting, + config=config, + ) + # In the beginning of training, the model is not fully trained and the generated token ids can be out of range + # so we clamp them to the correct range. + gen_token_ids = torch.clamp(gen_token_ids, max=accelerator.unwrap_model(model).config.codebook_size - 1, min=0) + images = vq_model.decode_code(gen_token_ids) + + model.train() + + if config.training.get("pre_encode", False): + del vq_model + + # Convert to PIL images + images = torch.clamp((images + 1.0) / 2.0, min=0.0, max=1.0) + images *= 255.0 + images = images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8) + pil_images = [Image.fromarray(image) for image in images] + + # Log images + wandb_images = [wandb.Image(image, caption=validation_prompts[i]) for i, image in enumerate(pil_images)] + wandb.log({"Generated images": wandb_images}, step=global_step) + + + +@torch.no_grad() +def understanding_images( + model, + vq_model, + uni_prompting, + accelerator, + config, + global_step, +): + logger.info("Understanding images...") + model.eval() + + file_list = os.listdir(config.dataset.params.mmu_image_root) + file_list = [f for f in file_list if f.lower().endswith(('.jpg', '.png', '.jpeg'))] + responses = ['' for i in range(len(file_list))] + images = [] + + device = accelerator.device + + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + else: + weight_dtype = torch.float32 + + for i, file_name in enumerate(file_list): + image_path = os.path.join(config.dataset.params.mmu_image_root, file_name) + image_ori = Image.open(image_path).convert("RGB") + image = image_transform(image_ori, resolution=config.dataset.params.resolution).to(device) + image = image.unsqueeze(0) + images.append(image) + image_tokens = vq_model.get_code(image) + len(uni_prompting.text_tokenizer) + batch_size = 1 + + input_ids = uni_prompting.text_tokenizer(['<|start_header_id|>user<|end_header_id|>\n' + "Please describe this image in detail." +'<|start_header_id|>assistant<|end_header_id|>\n'])['input_ids'] + input_ids = torch.tensor(input_ids).to(device) + + input_ids = torch.cat([ + (torch.ones(input_ids.shape[0], 1) * uni_prompting.sptids_dict['<|mmu|>']).to(device), + (torch.ones(input_ids.shape[0], 1) * uni_prompting.sptids_dict['<|soi|>']).to(device), + image_tokens, + (torch.ones(input_ids.shape[0], 1) * uni_prompting.sptids_dict['<|eoi|>']).to(device), + (torch.ones(input_ids.shape[0], 1) * uni_prompting.sptids_dict['<|sot|>']).to(device), + input_ids + ], dim=1).long() + with torch.autocast("cuda", dtype=weight_dtype, enabled=accelerator.mixed_precision != "no"): + output_ids = accelerator.unwrap_model(model).mmu_generate(input_ids) + # output_ids = torch.stack(output_ids).squeeze()[None] + + text = uni_prompting.text_tokenizer.batch_decode(output_ids[:, input_ids.shape[1]:], skip_special_tokens=True) + responses[i] += text[0] + model.train() + images = torch.cat(images, dim=0) + images = torch.clamp((images + 1.0) / 2.0, min=0.0, max=1.0) + images *= 255.0 + images = images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8) + pil_images = [Image.fromarray(image) for image in images] + + # Log images + wandb_images = [wandb.Image(image, caption=responses[i]) for i, image in enumerate(pil_images)] + wandb.log({"Understanding images": wandb_images}, step=global_step) + + +def save_checkpoint(model, config, accelerator, global_step): + output_dir = config.experiment.output_dir + checkpoints_total_limit = config.experiment.get("checkpoints_total_limit", None) + + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if accelerator.is_main_process and checkpoints_total_limit is not None: + checkpoints = os.listdir(output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= checkpoints_total_limit: + num_to_remove = len(checkpoints) - checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = Path(output_dir) / f"checkpoint-{global_step}" + + # retrieve the model on all processes for deepspeed stage 3 to work then save on one process (we are not using stage 3 yet) + # XXX: could also make this conditional on deepspeed + state_dict = accelerator.get_state_dict(model) + if accelerator.is_main_process: + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.save_pretrained( + save_path / "unwrapped_model", + save_function=accelerator.save, + state_dict=state_dict, + safe_serialization=True + ) + json.dump({"global_step": global_step}, (save_path / "metadata.json").open("w+")) + logger.info(f"Saved state to {save_path}") + + +def log_grad_norm(model, accelerator, global_step): + for name, param in model.named_parameters(): + if param.grad is not None: + grads = param.grad.detach().data + grad_norm = (grads.norm(p=2) / grads.numel()).item() + accelerator.log({"grad_norm/" + name: grad_norm}, step=global_step) + + +if __name__ == "__main__": + main() diff --git a/MMaDA/training/train_mmada_cot_sft.py b/MMaDA/training/train_mmada_cot_sft.py new file mode 100644 index 0000000000000000000000000000000000000000..a8e01aa2658d7d355a03fb4e27bb0584b35e7561 --- /dev/null +++ b/MMaDA/training/train_mmada_cot_sft.py @@ -0,0 +1,1234 @@ +# coding=utf-8 +# Copyright 2025 MMaDA Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +os.environ["TOKENIZERS_PARALLELISM"] = "true" +import json +import pandas +import logging +import math +import shutil +import time +import html +from pathlib import Path +from typing import Union + +import numpy as np +from PIL import Image +from omegaconf import OmegaConf +import wandb +import torch +from torch.optim import AdamW +from lightning.pytorch.utilities import CombinedLoader + +from transformers import AutoTokenizer, AutoConfig +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import DistributedType, set_seed + +from training.data import Text2ImageDataset +from training.utils import get_config, flatten_omega_conf, image_transform, image_transform_squash +from training.imagenet_dataset import ImageNetDataset +from parquet import RefinedWebDataset, ChatDataset + +from models import MAGVITv2, get_mask_schedule, MMadaModelLM, MMadaConfig +from training.prompting_utils import UniversalPrompting +from models.lr_schedulers import get_scheduler +from models.logging import set_verbosity_info, set_verbosity_error + +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler + +from training.utils import get_config, flatten_omega_conf, mask_or_random_replace_tokens, AverageMeter +from torchmetrics.functional.multimodal import clip_score +from functools import partial +import ImageReward as RM +try: + import apex + + is_apex_available = True +except ImportError: + is_apex_available = False + +logger = get_logger(__name__, log_level="INFO") + + +def get_vq_model_class(model_type): + if model_type == "magvitv2": + return MAGVITv2 + elif model_type == "vq16": + return VQ_16 + else: + raise ValueError(f"model_type {model_type} not supported.") + + +def main(): + ######################### + # SETUP Accelerator # + ######################### + config = get_config() + + # Enable TF32 on Ampere GPUs + if config.training.enable_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.benchmark = True + torch.backends.cudnn.deterministic = False + + config.experiment.logging_dir = str(Path(config.experiment.output_dir) / "logs") + accelerator = Accelerator( + gradient_accumulation_steps=config.training.gradient_accumulation_steps, + mixed_precision=config.training.mixed_precision, + log_with="wandb", + project_dir=config.experiment.logging_dir, + split_batches=True, + ) + + total_batch_size_per_gpu = (config.training.batch_size_t2i + + config.training.batch_size_lm + + config.training.batch_size_mmu) + total_batch_size = ( + (config.training.batch_size_t2i + config.training.batch_size_lm + config.training.batch_size_mmu) + * accelerator.num_processes * config.training.gradient_accumulation_steps + ) + + if accelerator.distributed_type == DistributedType.DEEPSPEED: + accelerator.state.deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = ( + total_batch_size_per_gpu + ) + + ##################################### + # SETUP LOGGING, SEED and CONFIG # + ##################################### + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + set_verbosity_info() + else: + set_verbosity_error() + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + resume_wandb_run = config.wandb.resume + run_id = config.wandb.get("run_id", None) + if run_id is None: + resume_wandb_run = False + run_id = wandb.util.generate_id() + config.wandb.run_id = run_id + + wandb_init_kwargs = dict( + name=config.experiment.name, + id=run_id, + resume=resume_wandb_run, + entity=config.wandb.get("entity", None), + config_exclude_keys=[], + ) + wandb_config = {k: v for k, v in flatten_omega_conf(config, resolve=True)} + wandb_config.pop("experiment.resume_from_checkpoint") + + accelerator.init_trackers( + config.experiment.project, + config=wandb_config, + init_kwargs={"wandb": wandb_init_kwargs}, + ) + + if accelerator.is_main_process: + os.makedirs(config.experiment.output_dir, exist_ok=True) + config_path = Path(config.experiment.output_dir) / "config.yaml" + logging.info(f"Saving config to {config_path}") + OmegaConf.save(config, config_path) + + # If passed along, set the training seed now. + if config.training.seed is not None: + set_seed(config.training.seed) + + ######################### + # MODELS and OPTIMIZER # + ######################### + logger.info("Loading models and optimizer") + + tokenizer = AutoTokenizer.from_pretrained(config.model.mmada.tokenizer_path, padding_side="left") + + uni_prompting = UniversalPrompting(tokenizer, max_text_len=config.dataset.preprocessing.max_seq_length, + special_tokens=( + "<|soi|>", "<|eoi|>", "<|sov|>", "<|eov|>", "<|t2i|>", + "<|mmu|>", "<|t2v|>", "<|v2v|>", "<|lvg|>" + ), + ignore_id=-100, cond_dropout_prob=config.training.cond_dropout_prob, use_reserved_token=True) + + print('special tokens : \n', uni_prompting.sptids_dict) + + # VQ model for processing image into discrete tokens + vq_model = get_vq_model_class(config.model.vq_model.type) + if config.model.vq_model.get("pretrained_model_path", None): + vq_model = vq_model().to(accelerator.device) + state_dict = torch.load(config.model.vq_model.pretrained_model_path)['model'] + vq_model.load_state_dict(state_dict) + else: + vq_model = vq_model.from_pretrained(config.model.vq_model.vq_model_name).to(accelerator.device) + vq_model.eval() + vq_model.requires_grad_(False) + + model = MMadaModelLM.from_pretrained(config.model.mmada.pretrained_model_path, torch_dtype=torch.bfloat16).to(accelerator.device) + + mask_id = model.config.mask_token_id + + ################################## + # Optimizer and LR scheduler # + ################################# + optimizer_config = config.optimizer.params + + # no decay on bias and layernorm and embedding + no_decay = ["bias", "layer_norm.weight", "mlm_ln.weight", "embeddings.weight"] + optimizer_grouped_parameters = [ + { + "params": [p for n, p in model.named_parameters() if + p.requires_grad and not any(nd in n for nd in no_decay)], + "weight_decay": optimizer_config.weight_decay, + }, + { + "params": [p for n, p in model.named_parameters() if + p.requires_grad and any(nd in n for nd in no_decay)], + "weight_decay": 0.0, + }, + ] + + optimizer_type = config.optimizer.name + if optimizer_type == "adamw": + optimizer = AdamW( + optimizer_grouped_parameters, + lr=optimizer_config.learning_rate, + betas=(optimizer_config.beta1, optimizer_config.beta2), + weight_decay=optimizer_config.weight_decay, + eps=optimizer_config.epsilon, + ) + else: + raise ValueError(f"Optimizer {optimizer_type} not supported") + + # Create mask scheduler + if config.get("mask_schedule", None) is not None: + schedule = config.mask_schedule.schedule + args = config.mask_schedule.get("params", {}) + mask_schedule = get_mask_schedule(schedule, **args) + else: + mask_schedule = get_mask_schedule(config.training.get("mask_schedule", "cosine")) + + lr_scheduler = get_scheduler( + config.lr_scheduler.scheduler, + optimizer=optimizer, + num_training_steps=config.training.max_train_steps, + num_warmup_steps=config.lr_scheduler.params.warmup_steps, + min_lr_scale=config.lr_scheduler.params.min_lr_scale + ) + + ################################## + # DATALOADER # + ################################# + logger.info("Creating dataloaders and lr_scheduler") + + total_batch_size_t2i_without_accum = config.training.batch_size_t2i * accelerator.num_processes + total_batch_size_t2i = ( + config.training.batch_size_t2i * accelerator.num_processes * config.training.gradient_accumulation_steps + ) + + # DataLoaders creation: + # We use webdataset for data loading. The dataloaders are created with sampling with replacement. + # We don't do dataset resuming here, instead we resample the shards and buffer each time. The sampling is stochastic. + # This means that the dataloading is not deterministic, but it's fast and efficient. + preproc_config = config.dataset.preprocessing + dataset_config = config.dataset.params + + # Data for generation + if config.dataset.gen_type == "t2i": + dataset = Text2ImageDataset( + train_shards_path_or_url=dataset_config.train_t2i_shards_path_or_url, + tokenizer=uni_prompting.text_tokenizer, # we want to get raw texts, tokenizer is just for length counting + max_seq_length=preproc_config.max_seq_length, + num_train_examples=config.experiment.max_train_examples_t2i, + per_gpu_batch_size=config.training.batch_size_t2i, + global_batch_size=total_batch_size_t2i_without_accum, + num_workers=dataset_config.num_workers, + resolution=preproc_config.resolution, + shuffle_buffer_size=dataset_config.shuffle_buffer_size, + pin_memory=dataset_config.pin_memory, + persistent_workers=dataset_config.persistent_workers, + external_caption_path=dataset_config.external_caption_path, + external_journeydb_caption_path=dataset_config.external_journeydb_caption_path, + external_laion12m_caption_path=dataset_config.external_laion12m_caption_path, + external_cc12m_caption_path=dataset_config.external_cc12m_caption_path, + external_text_to_image_2M_512_caption_path=dataset_config.external_text_to_image_2M_512_caption_path, + ) + train_dataloader_t2i = dataset.train_dataloader + num_update_steps_per_epoch = math.ceil( + train_dataloader_t2i.num_batches / config.training.gradient_accumulation_steps) + num_train_epochs = math.ceil(config.training.max_train_steps / num_update_steps_per_epoch) + + elif config.dataset.gen_type == "t2i_parquet": + # this part relies on the internal packages, which will not be released + num_update_steps_per_epoch = math.ceil(config.experiment.max_train_examples_t2i / total_batch_size_t2i) + num_train_epochs = math.ceil(config.training.max_train_steps / num_update_steps_per_epoch) + + train_dataloader_t2i = create_imagetext_dataloader( + train_shards_path_or_url=dataset_config.train_t2i_shards_path_or_url, + batch_size=config.training.batch_size_t2i, + image_size=preproc_config.resolution, + num_workers=dataset_config.num_workers, + num_readers=32, + predefined_steps=num_update_steps_per_epoch, + drop_last=True, + shuffle=True, + shuffle_buffer_size=dataset_config.shuffle_buffer_size + ) + + elif config.dataset.gen_type == "imagenet1k": + dataset_imagenet = ImageNetDataset( + dataset_config.train_t2i_shards_path_or_url, + image_size=preproc_config.resolution, + ) + + print('process index : ', + accelerator.process_index, ', ', accelerator.num_processes, + "Length: ", len(dataset_imagenet)) + + if accelerator.num_processes > 1: + sampler = DistributedSampler(dataset_imagenet, + num_replicas=accelerator.num_processes, + rank=accelerator.process_index, + shuffle=True, + ) + shuffle = False + else: + sampler = None + shuffle = True + + train_dataloader_t2i = DataLoader(dataset_imagenet, batch_size=config.training.batch_size_t2i, + sampler=sampler, collate_fn=dataset_imagenet.collate_fn, + shuffle=shuffle, num_workers=dataset_config.num_workers) + num_update_steps_per_epoch = math.ceil(len(dataset_imagenet) / total_batch_size_t2i) + num_train_epochs = math.ceil(config.training.max_train_steps / num_update_steps_per_epoch) + + else: + raise ValueError(f"Unsupported dataset type {config.dataset.type}") + + + total_batch_size_mmu_without_accum = config.training.batch_size_mmu * accelerator.num_processes + # Data for image captioning + if config.dataset.und_type == "captioning": + dataset_mmu = Text2ImageDataset( + train_shards_path_or_url=dataset_config.train_mmu_shards_path_or_url, + tokenizer=uni_prompting.text_tokenizer, # we want to get raw texts + max_seq_length=preproc_config.max_seq_length, + num_train_examples=config.experiment.max_train_examples_mmu, + per_gpu_batch_size=config.training.batch_size_mmu, + global_batch_size=total_batch_size_mmu_without_accum, + num_workers=dataset_config.num_workers, + resolution=preproc_config.resolution, + shuffle_buffer_size=dataset_config.shuffle_buffer_size, + pin_memory=dataset_config.pin_memory, + persistent_workers=dataset_config.persistent_workers, + external_caption_path=dataset_config.external_caption_path, + external_journeydb_caption_path=dataset_config.external_journeydb_caption_path, + external_laion12m_caption_path=dataset_config.external_laion12m_caption_path, + external_cc12m_caption_path=dataset_config.external_cc12m_caption_path, + external_text_to_image_2M_512_caption_path=dataset_config.external_text_to_image_2M_512_caption_path, + external_ai2d_caption_path=dataset_config.external_ai2d_caption_path, + external_clevr_caption_path=dataset_config.external_clevr_caption_path, + external_docvqa_caption_path=dataset_config.external_docvqa_caption_path, + external_geo_caption_path=dataset_config.external_geo_caption_path, + is_captioning=True, + add_caption_prompt=dataset_config.add_caption_prompt, + ) + train_dataloader_mmu = dataset_mmu.train_dataloader + + elif config.dataset.und_type == "captioning_parquet": + train_dataloader_mmu = create_imagetext_dataloader( + train_shards_path_or_url=dataset_config.train_mmu_shards_path_or_url, + batch_size=config.training.batch_size_mmu, + image_size=preproc_config.resolution, + num_workers=dataset_config.num_workers, + num_readers=32, + predefined_steps=num_update_steps_per_epoch, + drop_last=True, + shuffle=True, + shuffle_buffer_size=dataset_config.shuffle_buffer_size, + is_captioning=True + ) + + else: + raise NotImplementedError(f"Unsupported dataset type {config.dataset.und_type}") + + + dataset_lm = ChatDataset(data_path=dataset_config.train_lm_shards_path_or_url, + rank=accelerator.process_index, + world_size=accelerator.num_processes, + num_workers=dataset_config.num_workers, + max_length=preproc_config.max_lm_text_length, + tokenizer=uni_prompting.text_tokenizer, + ) + + train_dataloader_lm = torch.utils.data.DataLoader(dataset_lm, batch_size=config.training.batch_size_lm, + sampler=None, collate_fn=dataset_lm.collate_fn, + num_workers=dataset_config.num_workers) + + # Combine these dataloaders into a single iterable model + iterables = { + "t2i_flow": train_dataloader_t2i, + "lm_flow": train_dataloader_lm, + "mmu_flow": train_dataloader_mmu, + } + + # + combined_dataloader = CombinedLoader(iterables, mode=config.dataset.combined_loader_mode) + + ################################## + # MODEL RESUME # + ################################# + global_step = 0 + first_epoch = 0 + start_step = 0 + + if config.experiment.resume_from_checkpoint: + dirs = os.listdir(config.experiment.output_dir) + logger.info(f"dirs: {dirs}") + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + logger.info(f"path: {path}") + if path is not None: + path = os.path.join(config.experiment.output_dir, path) + logger.info(f"Resuming from checkpoint: {path}") + global_step = start_step = int(os.path.basename(path).split("-")[1]) + first_epoch = global_step // num_update_steps_per_epoch + if os.path.exists(f'{path}/unwrapped_model/pytorch_model.bin'): + state_dict = torch.load(f'{path}/unwrapped_model/pytorch_model.bin', map_location="cpu") + model.load_state_dict(state_dict, strict=True) + del state_dict + elif os.path.exists(f'{path}/unwrapped_model/pytorch_model.bin.index.json'): + from safetensors.torch import load_file + from transformers.modeling_utils import load_sharded_checkpoint + load_sharded_checkpoint(model, f'{path}/unwrapped_model/') + elif os.path.exists(f'{path}/unwrapped_model/model.safetensors.index.json'): + from transformers.modeling_utils import load_sharded_checkpoint + load_sharded_checkpoint( + model, + f'{path}/unwrapped_model/', + ) + else: + raise FileNotFoundError(f"Checkpoint {path}/unwrapped_model/pytorch_model.bin not found") + else: + logger.info("Not resuming from checkpoint") + + ################################## + # Prepare accelerator # + ################################# + logger.info("Preparing model, optimizer and dataloaders") + model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler) + + vq_model.to(device=accelerator.device) + + + mask_dtype = model.get_input_embeddings().weight.dtype + + ################################## + # Training # + ################################# + logger.info("***** Running training *****") + logger.info(f" Num training steps = {config.training.max_train_steps}") + logger.info(f" Instantaneous batch size per device = {total_batch_size_per_gpu}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {config.training.gradient_accumulation_steps}") + + @torch.no_grad() + def prepare_inputs_and_labels( + pixel_values_or_image_ids: Union[torch.FloatTensor, torch.LongTensor], + texts: Union[str, list[str]], + min_masking_rate: float = 0.0, + is_train: bool = True, + seed: int = None + ): + + image_tokens = vq_model.get_code(pixel_values_or_image_ids) + image_tokens = image_tokens + len(uni_prompting.text_tokenizer) + # create MLM mask and labels + input_ids, labels, loss_weight, mask_prob = mask_or_random_replace_tokens( + image_tokens, + mask_id, + config, + mask_schedule=mask_schedule, + is_train=is_train, + seed=seed + ) + input_ids, masks, labels = uni_prompting((texts, input_ids, labels), 't2i') + return input_ids, labels, mask_prob, image_tokens, masks + + @torch.no_grad() + def prepare_inputs_and_labels_for_text( + texts: Union[str, list[str]], max_seq_len, eps=1e-3 + ): + # create MLM mask and labels + + input_ids_lm, prompt_mask, labels_lm = uni_prompting((texts, max_seq_len), 'lm') + b, l = input_ids_lm.shape + t = torch.rand(b, device=input_ids_lm.device) + p_mask = (1 - eps) * t + eps + p_mask = p_mask[:, None].repeat(1, l) + + masked_indices = torch.rand((b, l), device=input_ids_lm.device) < p_mask + # 126336 is used for [MASK] token + noisy_batch = torch.where(masked_indices, mask_id, input_ids_lm) + masked_indices = noisy_batch == mask_id + + return noisy_batch, labels_lm, p_mask + + @torch.no_grad() + def prepare_inputs_and_labels_for_chat_text( + texts: Union[str, list[str]], max_seq_len, eps=1e-3 + ): + # create MLM mask and labels + + input_ids_lm, prompt_mask, labels_lm = uni_prompting((texts, max_seq_len), 'lm_chat') + b, l = input_ids_lm.shape + t = torch.rand(b, device=input_ids_lm.device) + p_mask = (1 - eps) * t + eps + p_mask = p_mask[:, None].repeat(1, l) + + masked_indices = torch.rand((b, l), device=input_ids_lm.device) < p_mask + # 126336 is used for [MASK] token + noisy_batch = torch.where(masked_indices, mask_id, input_ids_lm) + masked_indices = noisy_batch == mask_id + noisy_batch[prompt_mask.bool()] = input_ids_lm[prompt_mask.bool()] + masked_indices = noisy_batch == mask_id + answer_lengths_lm = torch.sum((1 - prompt_mask), dim=-1, keepdim=True) + answer_lengths_lm = answer_lengths_lm.repeat(1, noisy_batch.shape[1]) + + return noisy_batch, labels_lm, p_mask, answer_lengths_lm + + @torch.no_grad() + def prepare_inputs_and_labels_for_mmu( + input_ids_mmu, prompt_masks, labels_mmu, eps=1e-3 + ): + b, l = input_ids_mmu.shape + t = torch.rand(b, device=input_ids_mmu.device) + p_mask = (1 - eps) * t + eps + p_mask = p_mask[:, None].repeat(1, l) + + masked_indices = torch.rand((b, l), device=input_ids_mmu.device) < p_mask + # 126336 is used for [MASK] token + noisy_batch = torch.where(masked_indices, mask_id, input_ids_mmu) + masked_indices = noisy_batch == mask_id + noisy_batch[prompt_masks.bool()] = input_ids_mmu[prompt_masks.bool()] + masked_indices = noisy_batch == mask_id + + prompt_masks = prompt_masks.to(torch.int64) + answer_lengths = torch.sum((1 - prompt_masks), dim=-1, keepdim=True) + answer_lengths = answer_lengths.repeat(1, noisy_batch.shape[1]) + + return noisy_batch, labels_mmu, p_mask, answer_lengths + + + batch_time_m = AverageMeter() + data_time_m = AverageMeter() + end = time.time() + + for epoch in range(first_epoch, num_train_epochs): + model.train() + for batch, batch_idx, dataloader_idx in combined_dataloader: + + # for loss calculation + batch_size_t2i = batch["t2i_flow"]["images"].shape[0] + batch_size_lm = len(batch["lm_flow"]["input_ids"]) + batch_size_mmu = batch["mmu_flow"]["images"].shape[0] + + # *-------*-------*-------*-------*-------*-------*-------*-------*-------*-------*-------* + # Build formatted sequences for class-conditional/text-to-image generation + # *-------*-------*-------*-------*-------*-------*-------*-------*-------*-------*-------* + pixel_values, texts = batch["t2i_flow"]["images"], batch["t2i_flow"]["input_ids"] + pixel_values = pixel_values.to(accelerator.device, non_blocking=True) + data_time_m.update(time.time() - end) + # print(f"t2i texts: {texts}") + + # Encode images to image tokens, mask them and create input and labels + ( + input_ids, + labels, + mask_prob, + image_tokens_ori, + t2i_masks + ) = prepare_inputs_and_labels(pixel_values, texts, config.training.min_masking_rate) + + # *-------*-------*-------*-------*-------*-------*-------*-------*-------*-------*-------* + # Build formatted sequences for language modeling + # *-------*-------*-------*-------*-------*-------*-------*-------*-------*-------*-------* + max_seq_len = input_ids.shape[-1] + texts_lm = batch["lm_flow"]["input_ids"] + ( + input_ids_lm, + labels_lm, + p_mask_lm, + answer_lengths_lm + ) = prepare_inputs_and_labels_for_chat_text(texts_lm, max_seq_len) + input_ids = torch.cat((input_ids, input_ids_lm.to(input_ids.device)), dim=0) + labels = torch.cat((labels, labels_lm.to(input_ids.device)), dim=0) + + + # *-------*-------*-------*-------*-------*-------*-------*-------*-------*-------*-------* + # Build formatted sequences for captioning/multimodal understanding + # *-------*-------*-------*-------*-------*-------*-------*-------*-------*-------*-------* + if "llava" in config.dataset.und_type: + pixel_values_mmu, input_ids_mmu, labels_mmu = (batch["mmu_flow"]["images"], batch["mmu_flow"]["input_ids"],batch["mmu_flow"]["labels"]) + pixel_values_mmu = pixel_values_mmu.to(accelerator.device, non_blocking=True) + input_ids_mmu = input_ids_mmu.to(accelerator.device, non_blocking=True) + image_tokens_mmu = vq_model.get_code(pixel_values_mmu) + image_tokens_mmu = image_tokens_mmu + len(uni_prompting.text_tokenizer) + + input_ids_mmu = torch.cat([ + (torch.ones(input_ids_mmu.shape[0], 1) * uni_prompting.sptids_dict['<|mmu|>']).to( + accelerator.device), + (torch.ones(input_ids_mmu.shape[0], 1) * uni_prompting.sptids_dict['<|soi|>']).to( + accelerator.device), + image_tokens_mmu, + (torch.ones(input_ids_mmu.shape[0], 1) * uni_prompting.sptids_dict['<|eoi|>']).to( + accelerator.device), + input_ids_mmu, + ], dim=1).long() + + labels_mmu = torch.cat([ + (torch.ones(input_ids_mmu.shape[0], 1) * uni_prompting.ignore_id).to(accelerator.device), + (torch.ones(input_ids_mmu.shape[0], 1) * uni_prompting.ignore_id).to(accelerator.device), + torch.ones_like(image_tokens_mmu) * uni_prompting.ignore_id, + (torch.ones(input_ids_mmu.shape[0], 1) * uni_prompting.ignore_id).to(accelerator.device), + labels_mmu.to(accelerator.device) + ], dim=1).long() + + else: + + pixel_values_mmu, texts_mmu = batch["mmu_flow"]["images"], batch["mmu_flow"]["input_ids"] + pixel_values_mmu = pixel_values_mmu.to(accelerator.device, non_blocking=True) + image_tokens_mmu = vq_model.get_code(pixel_values_mmu) + image_tokens_mmu = image_tokens_mmu + len(uni_prompting.text_tokenizer) + + input_ids_mmu, prompt_masks, labels_mmu = uni_prompting((image_tokens_mmu, texts_mmu), 'mmu') + ( + input_ids_mmu, + labels_mmu, + p_mask_mmu, + answer_lengths + ) = prepare_inputs_and_labels_for_mmu(input_ids_mmu, prompt_masks, labels_mmu) + input_ids_mmu = input_ids_mmu.to(accelerator.device, non_blocking=True) + + + input_ids = torch.cat((input_ids, input_ids_mmu.to(input_ids.device)), dim=0) + labels = torch.cat((labels, labels_mmu.to(input_ids.device)), dim=0) + + if global_step == 0 and epoch == 0: + logger.info("Input ids: {}".format(input_ids)) + logger.info("Labels: {}".format(labels)) + + with accelerator.accumulate(model): + logits, loss_t2i, loss_lm, loss_mmu = model.forward_process( + input_ids=input_ids, + labels=labels, + batch_size_t2i=batch_size_t2i, + batch_size_lm=batch_size_lm, + batch_size_mmu=batch_size_mmu, + max_seq_length=config.dataset.preprocessing.max_seq_length, + p_mask_lm=p_mask_lm, + p_mask_mmu=p_mask_mmu, + answer_lengths=answer_lengths, + t2i_masks=t2i_masks, + answer_lengths_lm=answer_lengths_lm + ) + # Gather the losses across all processes for logging (if we use distributed training). + avg_loss_t2i = accelerator.gather(loss_t2i.repeat(config.training.batch_size_t2i)).mean() + avg_loss_lm = accelerator.gather(loss_lm.repeat(config.training.batch_size_lm)).mean() + avg_loss_mmu = accelerator.gather(loss_mmu.repeat(config.training.batch_size_mmu)).mean() + loss = config.training.t2i_coeff * loss_t2i + \ + config.training.lm_coeff * loss_lm + \ + config.training.mmu_coeff * loss_mmu + + avg_masking_rate = accelerator.gather(mask_prob.repeat(config.training.batch_size_t2i)).mean() + + accelerator.backward(loss) + + if config.training.max_grad_norm is not None and accelerator.sync_gradients: + accelerator.clip_grad_norm_(model.parameters(), config.training.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + + # log gradient norm before zeroing it + if ( + accelerator.sync_gradients + and (global_step + 1) % config.experiment.log_grad_norm_every == 0 + and accelerator.is_main_process + ): + log_grad_norm(model, accelerator, global_step + 1) + + optimizer.zero_grad(set_to_none=True) + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + + batch_time_m.update(time.time() - end) + end = time.time() + + # Log metrics + if (global_step + 1) % config.experiment.log_every == 0: + samples_per_second_per_gpu = ( + config.training.gradient_accumulation_steps * total_batch_size_per_gpu / batch_time_m.val + ) + logs = { + "step_loss_t2i": avg_loss_t2i.item(), + "step_loss_mmu": avg_loss_mmu.item(), + "step_loss_lm": avg_loss_lm.item(), + "lr": lr_scheduler.get_last_lr()[0], + "avg_masking_rate": avg_masking_rate.item(), + "samples/sec/gpu": samples_per_second_per_gpu, + "data_time": data_time_m.val, + "batch_time": batch_time_m.val, + } + accelerator.log(logs, step=global_step + 1) + + logger.info( + f"Step: {global_step + 1} " + f"Loss_t2i: {avg_loss_t2i.item():0.4f} " + f"Loss_mmu: {avg_loss_mmu.item():0.4f} " + f"Loss_lm: {avg_loss_lm.item():0.4f} " + f"Data (t): {data_time_m.val:0.4f}, {samples_per_second_per_gpu:0.2f}/s/gpu " + f"Batch (t): {batch_time_m.val:0.4f} " + f"LR: {lr_scheduler.get_last_lr()[0]:0.6f}" + ) + + # resetting batch / data time meters per log window + batch_time_m.reset() + data_time_m.reset() + + if (global_step + 1) % config.experiment.save_every == 0: + save_checkpoint(model, config, accelerator, global_step + 1, uni_prompting) + + if ((global_step + 1) % config.experiment.generate_every == 0 or global_step == start_step) and accelerator.is_main_process: + quantative_images( + model, + vq_model, + uni_prompting, + accelerator, + config, + global_step + 1, + mask_schedule=mask_schedule, + force_no_cfg=False + ) + + generate_images( + model, + vq_model, + uni_prompting, + accelerator, + config, + global_step + 1, + mask_schedule=mask_schedule, + force_no_cfg=False + ) + + visualize_predictions( + model, + vq_model, + uni_prompting, + config, + global_step + 1, + input_ids, + image_tokens_ori, + batch["t2i_flow"]["images"], + texts, + logits, + accelerator + ) + + understanding_images( + model, + vq_model, + uni_prompting, + accelerator, + config, + global_step + 1, + ) + + generate_chat_text( + model, + uni_prompting, + accelerator, + config, + global_step + 1, + ) + + global_step += 1 + # Stop training if max steps is reached + if global_step >= config.training.max_train_steps: + break + # End for + + accelerator.wait_for_everyone() + + # Evaluate and save checkpoint at the end of training + save_checkpoint(model, config, accelerator, global_step, uni_prompting) + + # Save the final trained checkpoint + if accelerator.is_main_process: + model = accelerator.unwrap_model(model) + model.save_pretrained(config.experiment.output_dir, safe_serialization=True) + + accelerator.end_training() + + +@torch.no_grad() +def visualize_predictions( + model, + vq_model, + uni_prompting, + config, + global_step, + input_ids, + image_tokens_ori, + ori_images, + texts, + logits, + accelerator +): + logger.info("Visualizing predictions...") + model.eval() + + recons_images = vq_model.decode_code(image_tokens_ori - len(uni_prompting.text_tokenizer)) + recons_images = torch.clamp((recons_images + 1.0) / 2.0, min=0.0, max=1.0) + recons_images *= 255.0 + recons_images = recons_images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8) + + images = torch.clamp((ori_images + 1.0) / 2.0, min=0.0, max=1.0) + images *= 255.0 + images = images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8) + predictions = logits[:config.training.batch_size_t2i, -(config.model.mmada.num_vq_tokens + 1):-1:, len(uni_prompting.text_tokenizer) + config.model.mmada.num_new_special_tokens: len(uni_prompting.text_tokenizer) + config.model.mmada.num_new_special_tokens + config.model.mmada.codebook_size] + predictions = predictions.argmax(axis=-1) + mask_token_id = accelerator.unwrap_model(model).config.mask_token_id - len(uni_prompting.text_tokenizer) + input_ids = input_ids[:config.training.batch_size_t2i, -(config.model.mmada.num_vq_tokens + 1):-1:] - len(uni_prompting.text_tokenizer) + mask_ratio = list((torch.where(input_ids == mask_token_id, 1, 0).sum( + dim=-1) / config.model.mmada.num_vq_tokens).cpu().numpy()) + predicted_images = torch.where(input_ids == mask_token_id, predictions, input_ids) + predicted_images = vq_model.decode_code(predicted_images) + predicted_images = torch.clamp((predicted_images + 1.0) / 2.0, min=0.0, max=1.0) + predicted_images *= 255.0 + predicted_images = predicted_images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8) + predicted_images = np.concatenate((images, recons_images, predicted_images), 2) + pil_images = [Image.fromarray(image) for image in predicted_images] + + # Log images + wandb_images = [wandb.Image(image, caption=f'mask ratio: {r:0.2f} \n caption: {texts[i]}') for i, (image, r) in + enumerate(zip(pil_images, mask_ratio))] + wandb.log({"Original images v.s. Reconstructed images v.s. Predicted images": wandb_images}, step=global_step) + + model.train() + + +@torch.no_grad() +def generate_images( + model, + vq_model, + uni_prompting, + accelerator, + config, + global_step, + mask_schedule, + force_no_cfg = False +): + logger.info("Generating images...") + model.eval() + + # read validation prompts from file + with open(config.dataset.params.validation_prompts_file, "r") as f: + validation_prompts = f.read().splitlines() + + mask_dtype = model.get_input_embeddings().weight.dtype + mask_token_id = accelerator.unwrap_model(model).config.mask_token_id + image_tokens = torch.ones((len(validation_prompts), config.model.mmada.num_vq_tokens), dtype=torch.long, + device=accelerator.device) * mask_token_id + input_ids, attention_mask = uni_prompting((validation_prompts, image_tokens), 't2i_gen') + if not force_no_cfg and config.training.guidance_scale > 0: + uncond_input_ids, uncond_attention_mask = uni_prompting(([''] * len(validation_prompts), image_tokens), 't2i_gen') + cfg_scale = config.training.guidance_scale + else: + uncond_input_ids = None + uncond_attention_mask = None + cfg_scale = 0 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + else: + weight_dtype = torch.float32 + + + with torch.autocast("cuda", dtype=weight_dtype, enabled=accelerator.mixed_precision != "no"): + # Generate images + gen_token_ids = accelerator.unwrap_model(model).t2i_generate( + input_ids=input_ids, + uncond_input_ids=uncond_input_ids, + attention_mask=attention_mask, + uncond_attention_mask=uncond_attention_mask, + guidance_scale=cfg_scale, + temperature=config.training.get("generation_temperature", 1.0), + timesteps=config.training.generation_timesteps, + noise_schedule=mask_schedule, + noise_type=config.training.get("noise_type", "mask"), + predict_all_tokens=config.training.get("predict_all_tokens", False), + seq_len=config.model.mmada.num_vq_tokens, + uni_prompting=uni_prompting, + config=config, + ) + # In the beginning of training, the model is not fully trained and the generated token ids can be out of range + # so we clamp them to the correct range. + gen_token_ids = torch.clamp(gen_token_ids, max=accelerator.unwrap_model(model).config.codebook_size - 1, min=0) + images = vq_model.decode_code(gen_token_ids) + + model.train() + + if config.training.get("pre_encode", False): + del vq_model + + # Convert to PIL images + images = torch.clamp((images + 1.0) / 2.0, min=0.0, max=1.0) + images *= 255.0 + images = images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8) + pil_images = [Image.fromarray(image) for image in images] + + # Log images + wandb_images = [wandb.Image(image, caption=validation_prompts[i]) for i, image in enumerate(pil_images)] + wandb.log({f"Generated images with cfg {cfg_scale}": wandb_images}, step=global_step) + + + + + +@torch.no_grad() +def quantative_images( + model, + vq_model, + uni_prompting, + accelerator, + config, + global_step, + mask_schedule, + force_no_cfg = False +): + logger.info("Quantative images...") + model.eval() + clip_score_fn = partial(clip_score, model_name_or_path="/data_storage/shared/pretrained_models/") + image_reward_model = RM.load("/data_storage/shared/pretrained_models/ImageReward/ImageReward.pt") + # read validation prompts from file + with open(config.validation.quantative_prompts_file, "r") as f: + validation_prompts = f.read().splitlines() + + mask_dtype = model.get_input_embeddings().weight.dtype + mask_token_id = accelerator.unwrap_model(model).config.mask_token_id + image_tokens = torch.ones((len(validation_prompts), config.model.mmada.num_vq_tokens), dtype=torch.long, + device=accelerator.device) * mask_token_id + input_ids, attention_mask = uni_prompting((validation_prompts, image_tokens), 't2i_gen') + if not force_no_cfg and config.training.guidance_scale > 0: + uncond_input_ids, uncond_attention_mask = uni_prompting(([''] * len(validation_prompts), image_tokens), 't2i_gen') + cfg_scale = config.training.guidance_scale + else: + uncond_input_ids = None + uncond_attention_mask = None + cfg_scale = 0 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + else: + weight_dtype = torch.float32 + + validation_batch_size = config.validation.quantative_batch_size + + pil_images = [] + clip_scores = [] + image_rewards = [] + for i in range(0, len(validation_prompts), validation_batch_size): + batch_input_ids = input_ids[i:i+validation_batch_size] + batch_attention_mask = attention_mask[i:i+validation_batch_size] + batch_uncond_input_ids = uncond_input_ids[i:i+validation_batch_size] + batch_uncond_attention_mask = uncond_attention_mask[i:i+validation_batch_size] + with torch.autocast("cuda", dtype=weight_dtype, enabled=accelerator.mixed_precision != "no"): + # Generate images + gen_token_ids = accelerator.unwrap_model(model).t2i_generate( + input_ids=batch_input_ids, + uncond_input_ids=batch_uncond_input_ids, + attention_mask=batch_attention_mask, + uncond_attention_mask=batch_uncond_attention_mask, + guidance_scale=cfg_scale, + temperature=config.training.get("generation_temperature", 1.0), + timesteps=config.training.generation_timesteps, + noise_schedule=mask_schedule, + noise_type=config.training.get("noise_type", "mask"), + predict_all_tokens=config.training.get("predict_all_tokens", False), + seq_len=config.model.mmada.num_vq_tokens, + uni_prompting=uni_prompting, + config=config, + ) + # In the beginning of training, the model is not fully trained and the generated token ids can be out of range + # so we clamp them to the correct range. + gen_token_ids = torch.clamp(gen_token_ids, max=accelerator.unwrap_model(model).config.codebook_size - 1, min=0) + images = vq_model.decode_code(gen_token_ids) + images = torch.clamp((images + 1.0) / 2.0, min=0.0, max=1.0) + images *= 255.0 + image_tensor = images.to(torch.uint8) + images = images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8) + batch_pil_images = [Image.fromarray(image) for image in images] + pil_images.extend(batch_pil_images) + + # calculate CLIP score + batch_clip_score = clip_score_fn(image_tensor, validation_prompts[i:i+validation_batch_size]) + # calculate image reward score + for j in range(validation_batch_size): + clip_scores.append(clip_score_fn(image_tensor[j], validation_prompts[i+j])) + image_reward_score = image_reward_model.score(validation_prompts[i+j], batch_pil_images[j]) + image_rewards.append(image_reward_score) + clip_scores = torch.tensor(clip_scores) + image_rewards = torch.tensor(image_rewards) + logger.info(f"clip_scores: {clip_scores}, image_rewards: {image_rewards}") + clip_scores_mean = clip_scores.mean() + image_rewards_mean = image_rewards.mean() + logger.info(f"CLIP score mean: {clip_scores_mean}, Image reward score mean: {image_rewards_mean}") + accelerator.log({"clip_score": clip_scores_mean, "image_reward_score": image_rewards_mean}, step=global_step) + wandb_images = [wandb.Image(image, caption=f"{validation_prompts[i]} \n CLIP score: {clip_scores[i]}, Image reward score: {image_rewards[i]}") for i, image in enumerate(pil_images[:validation_batch_size])] + wandb.log({f"Quantative images with cfg {cfg_scale}": wandb_images}, step=global_step) + + + if config.training.get("pre_encode", False): + del vq_model + + model.train() + + + + + +@torch.no_grad() +def understanding_images( + model, + vq_model, + uni_prompting, + accelerator, + config, + global_step, +): + logger.info("Understanding images...") + model.eval() + + prompts_file_path = config.dataset.params.mmu_validation_prompts_file + prompts_dict = {} + try: + with open(prompts_file_path, 'r') as f: + for line in f: + data = json.loads(line) + prompts_dict[data['file_name']] = data['prompt'] + except Exception as e: + logger.error(f"Error loading prompts from {prompts_file_path}: {e}. Using default prompt.") + default_prompt = '<|start_header_id|>user<|end_header_id|>\n' + "Please describe this image in detail." + '<|start_header_id|>assistant<|end_header_id|>\n' + + file_list = os.listdir(config.dataset.params.mmu_image_root) + file_list = [f for f in file_list if f.lower().endswith(('.jpg', '.png', '.jpeg'))] + file_list = sorted(file_list) + responses = ['' for i in range(len(file_list))] + questions = ['' for i in range(len(file_list))] + images = [] + + device = accelerator.device + + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + else: + weight_dtype = torch.float32 + + for i, file_name in enumerate(file_list): + image_path = os.path.join(config.dataset.params.mmu_image_root, file_name) + image_ori = Image.open(image_path).convert("RGB") + if 'ai2d' in file_name or 'clevr' in file_name or 'docvqa' in file_name or 'geo' in file_name: + image = image_transform_squash(image_ori, resolution=config.dataset.params.resolution).to(device) + else: + image = image_transform(image_ori, resolution=config.dataset.params.resolution).to(device) + image = image.unsqueeze(0) + images.append(image) + image_tokens = vq_model.get_code(image) + len(uni_prompting.text_tokenizer) + batch_size = 1 + + current_prompt = prompts_dict.get(file_name) + if current_prompt is None: + logger.warning(f"Prompt for {file_name} not found in {prompts_file_path}. Using default prompt.") + default_prompt_for_missing = '<|start_header_id|>user<|end_header_id|>\n' + "Please describe this image in detail." + '<|start_header_id|>assistant<|end_header_id|>\n' + current_prompt = default_prompt_for_missing if prompts_dict else default_prompt # å¦‚ęžœ prompts_dict äøŗē©ŗļ¼ˆåŠ č½½å¤±č“„ļ¼‰ļ¼Œåˆ™ä½æē”ØåŠ č½½å¤±č“„ę—¶ēš„é»˜č®¤å€¼ + input_ids = uni_prompting.text_tokenizer([current_prompt])['input_ids'] + input_ids = torch.tensor(input_ids).to(device) + + input_ids = torch.cat([ + (torch.ones(input_ids.shape[0], 1) * uni_prompting.sptids_dict['<|mmu|>']).to(device), + (torch.ones(input_ids.shape[0], 1) * uni_prompting.sptids_dict['<|soi|>']).to(device), + image_tokens, + (torch.ones(input_ids.shape[0], 1) * uni_prompting.sptids_dict['<|eoi|>']).to(device), + (torch.ones(input_ids.shape[0], 1) * uni_prompting.sptids_dict['<|sot|>']).to(device), + input_ids + ], dim=1).long() + with torch.autocast("cuda", dtype=weight_dtype, enabled=accelerator.mixed_precision != "no"): + output_ids = accelerator.unwrap_model(model).mmu_generate(input_ids, max_new_tokens=config.dataset.preprocessing.max_seq_length, steps=config.dataset.preprocessing.max_seq_length // 2, block_length=config.dataset.preprocessing.max_seq_length // 4) + + text = uni_prompting.text_tokenizer.batch_decode(output_ids[:, input_ids.shape[1]:], skip_special_tokens=True) + current_prompt = current_prompt.removeprefix("<|start_header_id|>user<|end_header_id|>\n").removesuffix("<|start_header_id|>assistant<|end_header_id|>\n") + questions[i] += current_prompt + responses[i] += text[0] + model.train() + images = torch.cat(images, dim=0) + images = torch.clamp((images + 1.0) / 2.0, min=0.0, max=1.0) + images *= 255.0 + images = images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8) + pil_images = [Image.fromarray(image) for image in images] + + # Log images + wandb_images = [ + wandb.Image( + image, + caption=f"**Question:** {questions[i]}\n**Response:** {responses[i]}" + ) + for i, image in enumerate(pil_images) + ] + wandb.log({"Understanding images": wandb_images}, step=global_step) + +@torch.no_grad() +def generate_chat_text( + model, + uni_prompting, + accelerator, + config, + global_step, +): + logger.info("Generating chat text...") + model.eval() + + df = pandas.read_json(config.dataset.params.lm_chat_validation_jsonl, lines=True) + prompts = df['question'].tolist() + responses = [''] * len(prompts) + + device = accelerator.device + + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + else: + weight_dtype = torch.float32 + + html_content = "
    " + html_content += f"

    Step {global_step}

    " + + for i, prompt in enumerate(prompts): + original_prompt = prompt + + prompt_with_tags = "<|start_header_id|>user<|end_header_id|>\n" + f"{prompt}" + "<|start_header_id|>assistant<|end_header_id|>\n" + token_ids = uni_prompting.text_tokenizer([prompt_with_tags])['input_ids'][0] + token_ids = [uni_prompting.text_tokenizer.bos_token_id] + token_ids + input_ids = torch.tensor(token_ids).unsqueeze(0).to(device) + + with torch.autocast("cuda", dtype=weight_dtype, enabled=accelerator.mixed_precision != "no"): + output_ids = accelerator.unwrap_model(model).mmu_generate( + input_ids, + max_new_tokens=config.dataset.preprocessing.max_seq_length, + steps=config.dataset.preprocessing.max_lm_text_length // 2, + block_length=config.dataset.preprocessing.max_seq_length // 4 + ) + text = uni_prompting.text_tokenizer.batch_decode(output_ids[:, input_ids.shape[1]:], skip_special_tokens=True) + responses[i] += text[0] + + escaped_prompt = html.escape(original_prompt) + escaped_response = html.escape(responses[i]) + html_content += f""" +
    +

    Prompt

    +

    {escaped_prompt}

    +

    Response

    +

    {escaped_response}

    +
    + """ + + html_content += "
    " + + model.train() + + wandb.log({"chat_text_generation": wandb.Html(html_content)}, step=global_step) + +def save_checkpoint(model, config, accelerator, global_step, uni_prompting): + output_dir = config.experiment.output_dir + checkpoints_total_limit = config.experiment.get("checkpoints_total_limit", None) + + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if accelerator.is_main_process and checkpoints_total_limit is not None: + checkpoints = os.listdir(output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= checkpoints_total_limit: + num_to_remove = len(checkpoints) - checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = Path(output_dir) / f"checkpoint-{global_step}" + + # retrieve the model on all processes for deepspeed stage 3 to work then save on one process (we are not using stage 3 yet) + # XXX: could also make this conditional on deepspeed + state_dict = accelerator.get_state_dict(model) + if accelerator.is_main_process: + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.save_pretrained( + save_path / "unwrapped_model", + save_function=accelerator.save, + state_dict=state_dict, + safe_serialization=True + ) + json.dump({"global_step": global_step}, (save_path / "metadata.json").open("w+")) + logger.info(f"Saved state to {save_path}") + + # save tokenizer + uni_prompting.text_tokenizer.save_pretrained(save_path/ "unwrapped_model") + + +def log_grad_norm(model, accelerator, global_step): + for name, param in model.named_parameters(): + if param.grad is not None: + grads = param.grad.detach().data + grad_norm = (grads.norm(p=2) / grads.numel()).item() + accelerator.log({"grad_norm/" + name: grad_norm}, step=global_step) + +if __name__ == "__main__": + main() diff --git a/MMaDA/training/train_mmada_i2i.py b/MMaDA/training/train_mmada_i2i.py new file mode 100644 index 0000000000000000000000000000000000000000..464006a42e263b19cbea1f45b463e9e33a81f989 --- /dev/null +++ b/MMaDA/training/train_mmada_i2i.py @@ -0,0 +1,840 @@ +# Copyright 2025 AIDAS Lab +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +os.environ["TOKENIZERS_PARALLELISM"] = "true" +import json +import logging +import math +import shutil +import time +from pathlib import Path +from typing import Union, List + +import numpy as np +from PIL import Image +from omegaconf import OmegaConf +import wandb +import torch +from torch.optim import AdamW +from lightning.pytorch.utilities import CombinedLoader + +from transformers import AutoTokenizer, AutoConfig +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import DistributedType, set_seed + +# +++++ I2I-specific Imports +++++ +from datasets import load_dataset +from torch.utils.data import Dataset, DataLoader +from tqdm.auto import tqdm +# ++++++++++++++++++++++++++++++ + +from training.data import Text2ImageDataset +from training.utils import get_config, flatten_omega_conf, image_transform +from training.imagenet_dataset import ImageNetDataset +from parquet import RefinedWebDataset + +from models import MAGVITv2, get_mask_schedule, MMadaModelLM, MMadaConfig +from training.prompting_utils import UniversalPrompting +from models.lr_schedulers import get_scheduler +from models.logging import set_verbosity_info, set_verbosity_error + +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler + + +SYSTEM_PROMPT_LEN = 28 + +from training.utils import get_config, flatten_omega_conf, mask_or_random_replace_tokens, AverageMeter + +try: + import apex + + is_apex_available = True +except ImportError: + is_apex_available = False + +logger = get_logger(__name__, log_level="INFO") + + +def get_vq_model_class(model_type): + if model_type == "magvitv2": + return MAGVITv2 + elif model_type == "vq16": + return VQ_16 + else: + raise ValueError(f"model_type {model_type} not supported.") + +# +++++ I2I Dataset Definition +++++ +class InstructPix2PixDataset(Dataset): + def __init__(self, hf_dataset, resolution): + self.dataset = hf_dataset + self.resolution = resolution + + def __len__(self): + return len(self.dataset) + + def __getitem__(self, idx): + sample = self.dataset[idx] + original_image = sample['original_image'].convert("RGB") + edit_prompt = sample['edit_prompt'] + edited_image = sample['edited_image'].convert("RGB") + + return { + "original_image": image_transform(original_image, self.resolution), + "edit_prompt": edit_prompt, + "edited_image": image_transform(edited_image, self.resolution) + } + +def collate_fn(batch): + original_images = torch.stack([item['original_image'] for item in batch]) + edit_prompts = [item['edit_prompt'] for item in batch] + edited_images = torch.stack([item['edited_image'] for item in batch]) + return { + 'original_images': original_images, + 'edit_prompts': edit_prompts, + 'edited_images': edited_images, + } +# ++++++++++++++++++++++++++++++++++ + +def main(): + ######################### + # SETUP Accelerator # + ######################### + config = get_config() + + + # Enable TF32 on Ampere GPUs + if config.training.enable_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.benchmark = True + torch.backends.cudnn.deterministic = False + + config.experiment.logging_dir = str(Path(config.experiment.output_dir) / "logs") + accelerator = Accelerator( + gradient_accumulation_steps=config.training.gradient_accumulation_steps, + mixed_precision=config.training.mixed_precision, + log_with="wandb", + project_dir=config.experiment.logging_dir, + ) + + total_batch_size_per_gpu = config.training.batch_size_i2i + total_batch_size = ( + config.training.batch_size_i2i + * accelerator.num_processes * config.training.gradient_accumulation_steps + ) + + if accelerator.distributed_type == DistributedType.DEEPSPEED: + accelerator.state.deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = ( + total_batch_size_per_gpu + ) + + ##################################### + # SETUP LOGGING, SEED and CONFIG # + ##################################### + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + set_verbosity_info() + else: + set_verbosity_error() + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + resume_wandb_run = config.wandb.resume + run_id = config.wandb.get("run_id", None) + if run_id is None: + resume_wandb_run = False + run_id = wandb.util.generate_id() + config.wandb.run_id = run_id + + wandb_init_kwargs = dict( + name=config.experiment.name, + id=run_id, + resume=resume_wandb_run, + entity=config.wandb.get("entity", None), + config_exclude_keys=[], + ) + wandb_config = {k: v for k, v in flatten_omega_conf(config, resolve=True)} + wandb_config.pop("experiment.resume_from_checkpoint") + + accelerator.init_trackers( + config.experiment.project, + config=wandb_config, + init_kwargs={"wandb": wandb_init_kwargs}, + ) + + if accelerator.is_main_process: + os.makedirs(config.experiment.output_dir, exist_ok=True) + config_path = Path(config.experiment.output_dir) / "config.yaml" + logging.info(f"Saving config to {config_path}") + OmegaConf.save(config, config_path) + + # If passed along, set the training seed now. + if config.training.seed is not None: + set_seed(config.training.seed) + + ######################### + # MODELS and OPTIMIZER # + ######################### + logger.info("Loading models and optimizer") + logger.info("="*50) + logger.info(f"max_train_steps from config: {config.training.max_train_steps}") + logger.info("="*50) + + tokenizer = AutoTokenizer.from_pretrained(config.model.mmada.tokenizer_path, padding_side="left") + + uni_prompting = UniversalPrompting(tokenizer, max_text_len=config.dataset.preprocessing.max_seq_length, + special_tokens=( + "<|soi|>", "<|eoi|>", "<|sov|>", "<|eov|>", "<|i2i|>", # Changed for I2I + ), + ignore_id=-100, cond_dropout_prob=config.training.cond_dropout_prob, use_reserved_token=True) + + print('special tokens : \n', uni_prompting.sptids_dict) + + # VQ model for processing image into discrete tokens + vq_model = get_vq_model_class(config.model.vq_model.type) + if config.model.vq_model.get("pretrained_model_path", None): + vq_model = vq_model().to(accelerator.device) + state_dict = torch.load(config.model.vq_model.pretrained_model_path)['model'] + vq_model.load_state_dict(state_dict) + else: + vq_model = vq_model.from_pretrained(config.model.vq_model.vq_model_name).to(accelerator.device) + vq_model.eval() + vq_model.requires_grad_(False) + + model = MMadaModelLM.from_pretrained(config.model.mmada.pretrained_model_path, torch_dtype=torch.bfloat16).to(accelerator.device) + + mask_id = model.config.mask_token_id + + ################################## + # Optimizer and LR scheduler # + ################################# + optimizer_config = config.optimizer.params + + # no decay on bias and layernorm and embedding + no_decay = ["bias", "layer_norm.weight", "mlm_ln.weight", "embeddings.weight"] + optimizer_grouped_parameters = [ + { + "params": [p for n, p in model.named_parameters() if + p.requires_grad and not any(nd in n for nd in no_decay)], + "weight_decay": optimizer_config.weight_decay, + }, + { + "params": [p for n, p in model.named_parameters() if + p.requires_grad and any(nd in n for nd in no_decay)], + "weight_decay": 0.0, + }, + ] + + optimizer_type = config.optimizer.name + if optimizer_type == "adamw": + optimizer = AdamW( + optimizer_grouped_parameters, + lr=optimizer_config.learning_rate, + betas=(optimizer_config.beta1, optimizer_config.beta2), + weight_decay=optimizer_config.weight_decay, + eps=optimizer_config.epsilon, + ) + else: + raise ValueError(f"Optimizer {optimizer_type} not supported") + + # Create mask scheduler + if config.get("mask_schedule", None) is not None: + schedule = config.mask_schedule.schedule + args = config.mask_schedule.get("params", {}) + mask_schedule = get_mask_schedule(schedule, **args) + else: + mask_schedule = get_mask_schedule(config.training.get("mask_schedule", "cosine")) + + # lr_warmup_steps_for_scheduler = config.lr_scheduler.params.warmup_steps + # max_train_steps_for_scheduler = config.training.max_train_steps + + lr_warmup_steps_for_scheduler = config.lr_scheduler.params.warmup_steps * accelerator.num_processes + max_train_steps_for_scheduler = config.training.max_train_steps * accelerator.num_processes + + lr_scheduler = get_scheduler( + config.lr_scheduler.scheduler, + optimizer=optimizer, + num_warmup_steps=lr_warmup_steps_for_scheduler, + num_training_steps=max_train_steps_for_scheduler, + min_lr_scale=config.lr_scheduler.params.min_lr_scale + ) + + ################################## + # DATALOADER # + ################################# + logger.info("Creating dataloader and lr_scheduler for I2I task") + + # Load dataset from Hugging Face and split into train/eval + logger.info("Loading and splitting the InstructPix2Pix dataset...") + full_dataset = load_dataset("timbrooks/instructpix2pix-clip-filtered", split="train") + + # Split the dataset: 90% for training, 10% for evaluation. + # Use a fixed seed for reproducible splits. + split_dataset = full_dataset.train_test_split(test_size=0.1, seed=config.training.seed) + + train_split = split_dataset['train'] + eval_split = split_dataset['test'] + + logger.info(f"Dataset split into {len(train_split)} training samples and {len(eval_split)} evaluation samples.") + + train_dataset = InstructPix2PixDataset( + hf_dataset=train_split, + resolution=config.dataset.preprocessing.resolution + ) + + eval_dataset = InstructPix2PixDataset( + hf_dataset=eval_split, + resolution=config.dataset.preprocessing.resolution + ) + + sampler = DistributedSampler(train_dataset, num_replicas=accelerator.num_processes, rank=accelerator.process_index, shuffle=True) if accelerator.num_processes > 1 else None + + train_dataloader = DataLoader( + train_dataset, + batch_size=config.training.batch_size_i2i, + shuffle=sampler is None, + sampler=sampler, + collate_fn=collate_fn, + num_workers=config.dataset.params.num_workers, + ) + + eval_dataloader = DataLoader( + eval_dataset, + batch_size=config.training.batch_size_i2i, + collate_fn=collate_fn, + num_workers=config.dataset.params.num_workers, + ) + + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / config.training.gradient_accumulation_steps) + num_train_epochs = math.ceil(config.training.max_train_steps / num_update_steps_per_epoch) + + ################################## + # MODEL RESUME # + ################################# + global_step = 0 + first_epoch = 0 + + # if config.experiment.resume_from_checkpoint: + # dirs = os.listdir(config.experiment.output_dir) + # logger.info(f"dirs: {dirs}") + # dirs = [d for d in dirs if d.startswith("checkpoint")] + # dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + # path = dirs[-1] if len(dirs) > 0 else None + # logger.info(f"path: {path}") + # if path is not None: + # path = os.path.join(config.experiment.output_dir, path) + # logger.info(f"Resuming from checkpoint: {path}") + # global_step = int(os.path.basename(path).split("-")[1]) + # first_epoch = global_step // num_update_steps_per_epoch + # if os.path.exists(f'{path}/unwrapped_model/pytorch_model.bin'): + # state_dict = torch.load(f'{path}/unwrapped_model/pytorch_model.bin', map_location="cpu") + # model.load_state_dict(state_dict, strict=True) + # del state_dict + # elif os.path.exists(f'{path}/unwrapped_model/pytorch_model.bin.index.json'): + # from transformers.modeling_utils import load_sharded_checkpoint + # load_sharded_checkpoint(model, f'{path}/unwrapped_model/') + # # if safetensors sharded checkpoint exists + # elif os.path.exists(f'{path}/unwrapped_model/model.safetensors.index.json'): + # from transformers.modeling_utils import load_sharded_checkpoint + # load_sharded_checkpoint( + # model, + # f'{path}/unwrapped_model/', + # ) + # else: + # raise FileNotFoundError(f"Checkpoint {path}/unwrapped_model/pytorch_model.bin not found") + if config.experiment.resume_from_checkpoint: + output_dir = Path(config.experiment.output_dir) + if os.path.exists(output_dir): + checkpoint_dirs = [d for d in os.listdir(output_dir) if d.startswith("checkpoint-")] + if checkpoint_dirs: + latest_checkpoint = sorted(checkpoint_dirs, key=lambda x: int(x.split("-")[1]))[-1] + resume_path = os.path.join(output_dir, latest_checkpoint) + + logger.info(f"Resuming from checkpoint: {resume_path}") + + accelerator.load_state(resume_path) + + global_step = int(latest_checkpoint.split("-")[1]) + first_epoch = global_step // num_update_steps_per_epoch + else: + logger.info("No checkpoint found to resume from, starting from scratch.") + else: + logger.info("Output directory does not exist, starting from scratch.") + else: + logger.info("Not resuming from checkpoint.") + + ################################## + # Prepare accelerator # + ################################# + logger.info("Preparing model, optimizer and dataloader") + model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + model, optimizer, train_dataloader, lr_scheduler + ) + lr_scheduler = get_scheduler( + config.lr_scheduler.scheduler, + optimizer=optimizer, + num_training_steps=config.training.max_train_steps, + num_warmup_steps=config.lr_scheduler.params.warmup_steps, + min_lr_scale=config.lr_scheduler.params.min_lr_scale + ) + + vq_model.to(device=accelerator.device) + + ################################## + # Training # + ################################# + logger.info("***** Running training *****") + logger.info(f" Num train examples = {len(train_dataset)}") + logger.info(f" Num Epochs = {num_train_epochs}") + logger.info(f" Num training steps = {config.training.max_train_steps}") + logger.info(f" Instantaneous batch size per device = {total_batch_size_per_gpu}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {config.training.gradient_accumulation_steps}") + + @torch.no_grad() + def prepare_inputs_and_labels_for_i2i( + original_pixel_values: torch.FloatTensor, + edited_pixel_values: torch.FloatTensor, + prompts: List[str], + is_train: bool = True, + ): + # Tokenize both original and edited images + original_image_tokens = vq_model.get_code(original_pixel_values) + len(uni_prompting.text_tokenizer) + edited_image_tokens = vq_model.get_code(edited_pixel_values) + len(uni_prompting.text_tokenizer) + + if is_train and torch.rand(1).item() < config.training.cond_dropout_prob: + # Dropout O: + prompts = [''] * len(prompts) + target_image_tokens = original_image_tokens + else: + # Dropout X: + prompts = prompts + target_image_tokens = edited_image_tokens + + # Mask the target (edited) image tokens for prediction + masked_edited_tokens, labels, _, mask_prob = mask_or_random_replace_tokens( + target_image_tokens, + mask_id, + config, + mask_schedule=mask_schedule, + is_train=is_train, + ) + + # Create the full input sequence: [i2i] [soi] original_img [eoi] prompt [soi] masked_edited_img [eoi] + input_ids, masks, labels = uni_prompting( + (prompts, original_image_tokens, masked_edited_tokens, labels), 'i2i' + ) + + return input_ids, labels, mask_prob, edited_image_tokens, masks + + def run_evaluation(model, eval_dataloader, accelerator, global_step): + logger.info(f"Running evaluation at step {global_step}...") + model.eval() + eval_losses = [] + with torch.no_grad(): + for batch in tqdm(eval_dataloader, desc="Evaluating", disable=not accelerator.is_local_main_process): + original_images = batch["original_images"].to(accelerator.device, non_blocking=True) + edited_images = batch["edited_images"].to(accelerator.device, non_blocking=True) + edit_prompts = batch["edit_prompts"] + input_ids, labels, _, _, masks = prepare_inputs_and_labels_for_i2i( + original_images, edited_images, edit_prompts, is_train=False + ) + _, loss_i2i = accelerator.unwrap_model(model).forward_i2i( + input_ids=input_ids, attention_mask=masks, labels=labels + ) + gathered_losses = accelerator.gather(loss_i2i.repeat(original_images.shape[0])) + eval_losses.append(gathered_losses) + avg_loss = torch.cat(eval_losses).mean() + logger.info(f"Step {global_step}: Evaluation Loss = {avg_loss.item()}") + accelerator.log({"evaluation/eval_loss": avg_loss.item()}, step=global_step) + model.train() + + batch_time_m = AverageMeter() + data_time_m = AverageMeter() + end = time.time() + + for epoch in range(first_epoch, num_train_epochs): + model.train() + for step, batch in enumerate(train_dataloader): + data_time_m.update(time.time() - end) + + # Generate and log validation images + if ((global_step + 1) % config.experiment.generate_every == 0 or global_step == 0) and accelerator.is_main_process: + + generate_i2i_images( + model, + vq_model, + uni_prompting, + accelerator, + config, + global_step + 1, + mask_schedule=mask_schedule, + val_dataset=eval_split + ) + if accelerator.is_main_process: + torch.cuda.empty_cache() + + original_images = batch["original_images"].to(accelerator.device, non_blocking=True) + edited_images = batch["edited_images"].to(accelerator.device, non_blocking=True) + edit_prompts = batch["edit_prompts"] + + batch_size = original_images.shape[0] + + ( + input_ids, + labels, + mask_prob, + gt_edited_tokens, + masks + ) = prepare_inputs_and_labels_for_i2i(original_images, edited_images, edit_prompts) + + if global_step == 0 and epoch == 0 and step == 0: + logger.info("Input ids: {}".format(input_ids)) + logger.info("Labels: {}".format(labels)) + + # with accelerator.accumulate(model): + logits, loss_i2i = model.forward_i2i( + input_ids=input_ids, + attention_mask=masks, + labels=labels + ) + + # Gather the losses across all processes for logging + avg_loss_i2i = accelerator.gather(loss_i2i.repeat(batch_size)).mean() + + accelerator.backward(loss_i2i) + + if config.training.max_grad_norm is not None and accelerator.sync_gradients: + accelerator.clip_grad_norm_(model.parameters(), config.training.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + + # Checks if the accelerator has performed an optimization step + if accelerator.sync_gradients: + batch_time_m.update(time.time() - end) + end = time.time() + + avg_masking_rate = accelerator.gather(mask_prob.repeat(batch_size)).mean() + + # Log metrics + if (global_step + 1) % config.experiment.log_every == 0: + samples_per_second_per_gpu = ( + config.training.gradient_accumulation_steps * batch_size / batch_time_m.val + ) + logs = { + "step_loss_i2i": avg_loss_i2i.item(), + "lr": lr_scheduler.get_last_lr()[0], + "avg_masking_rate": avg_masking_rate.item(), + "samples/sec/gpu": samples_per_second_per_gpu, + "data_time": data_time_m.val, + "batch_time": batch_time_m.val, + } + accelerator.log(logs, step=global_step + 1) + + logger.info( + f"Step: {global_step + 1} " + f"Loss_i2i: {avg_loss_i2i.item():0.4f} " + f"Data (t): {data_time_m.val:0.4f}, {samples_per_second_per_gpu:0.2f}/s/gpu " + f"Batch (t): {batch_time_m.val:0.4f} " + # f"LR_temp: {optimizer.param_groups[0]['lr']}" + f"LR: {lr_scheduler.get_last_lr()[0]:0.6f}" + ) + batch_time_m.reset() + data_time_m.reset() + + # Save model checkpoint + if (global_step + 1) % config.experiment.save_every == 0: + # save_checkpoint(model, config, accelerator, global_step + 1, uni_prompting) + save_checkpoint( + accelerator, + config.experiment.output_dir, + global_step + 1, + uni_prompting, + config.experiment.get("checkpoints_total_limit") + ) + + if (global_step + 1) % config.experiment.eval_every == 0 : + run_evaluation(model, eval_dataloader, accelerator, global_step + 1) + if accelerator.is_main_process: + torch.cuda.empty_cache() + + global_step += 1 + + if global_step >= config.training.max_train_steps: + break + if global_step >= config.training.max_train_steps: + break + + accelerator.wait_for_everyone() + save_checkpoint( + accelerator, + config.experiment.output_dir, + global_step + 1, + uni_prompting, + config.experiment.get("checkpoints_total_limit") + ) + if accelerator.is_main_process: + model = accelerator.unwrap_model(model) + model.save_pretrained(config.experiment.output_dir, safe_serialization=True) + + accelerator.end_training() + + +@torch.no_grad() +def generate_i2i_images( + model, + vq_model, + uni_prompting, + accelerator, + config, + global_step, + mask_schedule, + val_dataset +): + logger.info("Generating I2I images for validation...") + model.eval() + + # Load a few validation samples + # = load_dataset("timbrooks/instructpix2pix-clip-filtered", split="train") + torch.cuda.empty_cache() + mask_token_id = accelerator.unwrap_model(model).config.mask_token_id + val_samples = val_dataset.select(range(config.experiment.get("num_validation_images", 4))) + + original_images_pil = [sample['original_image'].convert("RGB") for sample in val_samples] + prompts = [sample['edit_prompt'] for sample in val_samples] + edited_images_pil = [sample['edited_image'].convert("RGB") for sample in val_samples] + + original_images_list = [image_transform(img,resolution=config.dataset.params.resolution).unsqueeze(0) for img in original_images_pil] + + original_images = torch.cat(original_images_list, dim=0).to(accelerator.device) + original_image_tokens = vq_model.get_code(original_images) + len(uni_prompting.text_tokenizer) + + output_image_placeholder = torch.ones((len(prompts), config.model.mmada.num_vq_tokens), + dtype=torch.long, device=accelerator.device) * mask_token_id + + input_ids, attention_mask = uni_prompting( + (prompts, original_image_tokens, output_image_placeholder), 'i2i_gen' + ) + + if config.training.guidance_scale > 0: + uncond_input_ids, uncond_attention_mask = uni_prompting(([''] * len(prompts), original_image_tokens, + output_image_placeholder), 'i2i_gen') + else: + uncond_input_ids = None + uncond_attention_mask = None + + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + else: + weight_dtype = torch.float32 + + with torch.autocast("cuda", dtype=weight_dtype, enabled=accelerator.mixed_precision != "no"): + gen_token_ids = model.i2i_generate( + input_ids=input_ids, + uncond_input_ids=uncond_input_ids, + attention_mask=attention_mask, + uncond_attention_mask=uncond_attention_mask, + guidance_scale=config.training.guidance_scale, + temperature=config.training.get("generation_temperature", 1.0), + timesteps=config.training.generation_timesteps, + noise_schedule=mask_schedule, + noise_type=config.training.get("noise_type", "mask"), + seq_len=config.model.mmada.num_vq_tokens, + uni_prompting=uni_prompting, + config=config, + ) + + gen_token_ids = torch.clamp(gen_token_ids, max=accelerator.unwrap_model(model).config.codebook_size - 1, min=0) + generated_images = vq_model.decode_code(gen_token_ids) + model.train() + + # Convert to PIL images for logging + generated_images = torch.clamp((generated_images + 1.0) / 2.0, min=0.0, max=1.0) + generated_images *= 255.0 + generated_images = generated_images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8) + gen_images_pil = [Image.fromarray(image) for image in generated_images] + + # Log images to wandb + wandb_images = [] + log_resolution = (512, 512) + for i in range(len(prompts)): + source_resized = original_images_pil[i].resize(log_resolution, Image.Resampling.LANCZOS) + generated_resized = gen_images_pil[i].resize(log_resolution, Image.Resampling.LANCZOS) + gt_resized = edited_images_pil[i].resize(log_resolution, Image.Resampling.LANCZOS) + + w, h = log_resolution + composite_image = Image.new('RGB', (w * 3, h)) + + composite_image.paste(source_resized, (0, 0)) + composite_image.paste(generated_resized, (w, 0)) + composite_image.paste(gt_resized, (w * 2, 0)) + + caption = f"Prompt: {prompts[i]}" + wandb_images.append(wandb.Image(composite_image, caption=caption)) + + wandb.log({"Generated I2I Examples": wandb_images}, step=global_step) + logger.info("Generating Validation Images DONE!!") + +@torch.no_grad() +def visualize_i2i_predictions( + model, + vq_model, + uni_prompting, + config, + global_step, + batch, + accelerator, + num_images_to_log=4, +): + logger.info("Visualizing predictions on a training batch...") + model.eval() + + # Log only a few images from the batch to prevent clutter + prompts = batch["edit_prompts"][:num_images_to_log] + original_images = batch["original_images"][:num_images_to_log] + gt_edited_images = batch["edited_images"][:num_images_to_log] + + # Tokenize the source images + original_image_tokens = vq_model.get_code(original_images) + len(uni_prompting.text_tokenizer) + + # Determine the data type for generation + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + else: + weight_dtype = torch.float32 + + # Generate images from the source image and prompt + with torch.autocast("cuda", dtype=weight_dtype, enabled=accelerator.mixed_precision != "no"): + gen_token_ids = accelerator.unwrap_model(model)( + original_image_tokens=original_image_tokens, + prompts=prompts, + temperature=config.training.get("generation_temperature", 1.0), + timesteps=config.training.generation_timesteps, + noise_schedule=get_mask_schedule(config.training.get("mask_schedule", "cosine")), + uni_prompting=uni_prompting, + config=config, + ) + + # Decode the generated tokens into images + gen_token_ids = torch.clamp(gen_token_ids, max=accelerator.unwrap_model(model).config.codebook_size - 1, min=0) + generated_images = vq_model.decode_code(gen_token_ids) + + # Convert all tensors to PIL images for visualization + def to_pil(tensor_images): + tensor_images = torch.clamp((tensor_images + 1.0) / 2.0, min=0.0, max=1.0) + tensor_images *= 255.0 + tensor_images = tensor_images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8) + return [Image.fromarray(image) for image in tensor_images] + + source_pils = to_pil(original_images) + generated_pils = to_pil(generated_images) + gt_edited_pils = to_pil(gt_edited_images) + + # Log images to wandb + wandb_images = [] + for i in range(num_images_to_log): + # Create a composite image: [Source | Generated | Ground Truth] + w, h = source_pils[i].size + composite_image = Image.new('RGB', (w * 3, h)) + composite_image.paste(source_pils[i], (0, 0)) + composite_image.paste(generated_pils[i], (w, 0)) + composite_image.paste(gt_edited_pils[i], (w * 2, 0)) + + caption = f"Prompt: {prompts[i]}\n(Left: Source, Middle: Generated, Right: Ground Truth)" + wandb_images.append(wandb.Image(composite_image, caption=caption)) + + wandb.log({"Training Batch Predictions": wandb_images}, step=global_step) + + model.train() # Set model back to training mode + +# def save_checkpoint(model, config, accelerator, global_step, uni_prompting): +# output_dir = config.experiment.output_dir +# checkpoints_total_limit = config.experiment.get("checkpoints_total_limit", None) + +# # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` +# if accelerator.is_main_process and checkpoints_total_limit is not None: +# checkpoints = os.listdir(output_dir) +# checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] +# checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + +# if len(checkpoints) >= checkpoints_total_limit: +# num_to_remove = len(checkpoints) - checkpoints_total_limit + 1 +# removing_checkpoints = checkpoints[0:num_to_remove] +# logger.info(f"Removing {len(removing_checkpoints)} checkpoints to keep limit of {checkpoints_total_limit}") +# for removing_checkpoint in removing_checkpoints: +# shutil.rmtree(os.path.join(output_dir, removing_checkpoint)) + +# save_path = Path(output_dir) / f"checkpoint-{global_step}" + +# state_dict = accelerator.get_state_dict(model) +# if accelerator.is_main_process: +# unwrapped_model = accelerator.unwrap_model(model) +# unwrapped_model.save_pretrained( +# save_path / "unwrapped_model", +# save_function=accelerator.save, +# state_dict=state_dict, +# safe_serialization=True +# ) +# json.dump({"global_step": global_step}, (save_path / "metadata.json").open("w+")) +# logger.info(f"Saved state to {save_path}") +# uni_prompting.text_tokenizer.save_pretrained(save_path / "unwrapped_model") + +def save_checkpoint(accelerator, output_dir, global_step, uni_prompting, checkpoints_total_limit=None): + + if accelerator.is_main_process and checkpoints_total_limit is not None: + checkpoints = os.listdir(output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint-")] + + if len(checkpoints) >= checkpoints_total_limit: + checkpoints_to_remove = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))[:len(checkpoints) - checkpoints_total_limit + 1] + logger.info(f"Removing {len(checkpoints_to_remove)} old checkpoints to keep a total of {checkpoints_total_limit}") + for ckpt in checkpoints_to_remove: + shutil.rmtree(os.path.join(output_dir, ckpt)) + + save_path = os.path.join(output_dir, f"checkpoint-{global_step}") + + accelerator.save_state(save_path) + logger.info(f"Saved complete state to {save_path}") + + if accelerator.is_main_process: + uni_prompting.text_tokenizer.save_pretrained(save_path) + +def log_grad_norm(model, accelerator, global_step): + if not accelerator.is_main_process: + return + for name, param in model.named_parameters(): + if param.grad is not None: + grads = param.grad.detach().data + grad_norm = (grads.norm(p=2) / grads.numel()).item() + accelerator.log({"grad_norm/" + name: grad_norm}, step=global_step) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/MMaDA/training/train_mmada_s2t.py b/MMaDA/training/train_mmada_s2t.py new file mode 100644 index 0000000000000000000000000000000000000000..cc954b61ec4f74b1ebfb1ac45a55e62813facbe1 --- /dev/null +++ b/MMaDA/training/train_mmada_s2t.py @@ -0,0 +1,568 @@ +# Copyright 2025 AIDAS Lab +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +warnings.filterwarnings("ignore") + +import os +import sys +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +# sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +os.environ["TOKENIZERS_PARALLELISM"] = "true" +import json +import logging +import math +import shutil +import time +from pathlib import Path +from typing import Union, List + +import numpy as np +from PIL import Image +from omegaconf import OmegaConf +import wandb +import torch +from torch.optim import AdamW +from lightning.pytorch.utilities import CombinedLoader + +from transformers import AutoTokenizer, AutoConfig +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import DistributedType, set_seed + +# +++++ I2I-specific Imports +++++ +from datasets import load_dataset +from torch.utils.data import Dataset, DataLoader +from tqdm.auto import tqdm +# ++++++++++++++++++++++++++++++ + +# +++++ S2T-specific Imports +++++ +from models.modeling_emova_speech_tokenizer import EMOVASpeechTokenizer +from datasets import load_dataset +from torch.utils.data import Dataset, DataLoader, DistributedSampler +from tqdm.auto import tqdm +from training.data import SpeechTextDataset +# import librosa +# ++++++++++++++++++++++++++++++ + +from training.data import Text2ImageDataset +from training.utils import get_config, flatten_omega_conf, image_transform +from training.imagenet_dataset import ImageNetDataset +from parquet import RefinedWebDataset + +from models import MAGVITv2, get_mask_schedule, MMadaModelLM, MMadaConfig +from training.prompting_utils import UniversalPrompting +from models.lr_schedulers import get_scheduler +from models.logging import set_verbosity_info, set_verbosity_error + +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler + + +SYSTEM_PROMPT_LEN = 28 + +from training.utils import get_config, flatten_omega_conf, mask_or_random_replace_tokens, AverageMeter + +try: + import apex + + is_apex_available = True +except ImportError: + is_apex_available = False + +logger = get_logger(__name__, log_level="INFO") + +def resize_vocab(model, config): + logger.info(f"Resizing token embeddings to {config.model.mmada.new_vocab_size}") + model.resize_token_embeddings(config.model.mmada.new_vocab_size) + +def get_vq_model_class(model_type): + if model_type == "magvitv2": + return MAGVITv2 + elif model_type == "emova": + return EMOVASpeechTokenizer.from_pretrained( + "Emova-ollm/emova_speech_tokenizer_hf" + ) + else: + raise ValueError(f"model_type {model_type} not supported.") + +def collate_fn(batch): + # In this setup, the tokenizer handles batching of audio paths + return { + 'audio_path': [item['audio_path'] for item in batch], + 'text': [item['text'] for item in batch], + } +# ++++++++++++++++++++++++++++++++++ + +def main(): + ######################### + # SETUP Accelerator # + ######################### + config = get_config() + + # Enable TF32 on Ampere GPUs + if config.training.enable_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.benchmark = True + torch.backends.cudnn.deterministic = False + + config.experiment.logging_dir = str(Path(config.experiment.output_dir) / "logs") + accelerator = Accelerator( + gradient_accumulation_steps=config.training.gradient_accumulation_steps, + mixed_precision=config.training.mixed_precision, + log_with="wandb", + project_dir=config.experiment.logging_dir, + ) + + total_batch_size_per_gpu = config.training.batch_size_s2t + total_batch_size = ( + config.training.batch_size_s2t + * accelerator.num_processes * config.training.gradient_accumulation_steps + ) + + if accelerator.distributed_type == DistributedType.DEEPSPEED: + accelerator.state.deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = ( + total_batch_size_per_gpu + ) + + ##################################### + # SETUP LOGGING, SEED and CONFIG # + ##################################### + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + set_verbosity_info() + else: + set_verbosity_error() + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + resume_wandb_run = config.wandb.resume + run_id = config.wandb.get("run_id", None) + if run_id is None: + resume_wandb_run = False + run_id = wandb.util.generate_id() + config.wandb.run_id = run_id + + wandb_init_kwargs = dict( + name=config.experiment.name, + id=run_id, + resume=resume_wandb_run, + entity=config.wandb.get("entity", None), + config_exclude_keys=[], + ) + wandb_config = {k: v for k, v in flatten_omega_conf(config, resolve=True)} + wandb_config.pop("experiment.resume_from_checkpoint") + + accelerator.init_trackers( + config.experiment.project, + config=wandb_config, + init_kwargs={"wandb": wandb_init_kwargs}, + ) + + if accelerator.is_main_process: + os.makedirs(config.experiment.output_dir, exist_ok=True) + config_path = Path(config.experiment.output_dir) / "config.yaml" + logging.info(f"Saving config to {config_path}") + OmegaConf.save(config, config_path) + + # If passed along, set the training seed now. + if config.training.seed is not None: + set_seed(config.training.seed) + + ######################### + # MODELS and OPTIMIZER # + ######################### + logger.info("Loading models and optimizer") + logger.info("="*50) + logger.info(f"max_train_steps from config: {config.training.max_train_steps}") + logger.info("="*50) + + tokenizer = AutoTokenizer.from_pretrained(config.model.mmada.tokenizer_path, padding_side="left") + + uni_prompting = UniversalPrompting(tokenizer, max_text_len=config.dataset.preprocessing.max_seq_length, + special_tokens=("<|s2t|>", "<|soa|>", "<|eoa|>", "<|soi|>", "<|eoi|>", "<|sov|>", "<|eov|>", "<|t2i|>", "<|mmu|>", "<|t2v|>", "<|v2v|>", "<|lvg|>"), + ignore_id=-100, cond_dropout_prob=config.training.cond_dropout_prob, use_reserved_token=True) + + logger.info('special tokens : \n', uni_prompting.sptids_dict) + + # VQ model for processing image into discrete tokens + vq_model = get_vq_model_class(config.model.vq_model.type) + vq_model = vq_model.from_pretrained(config.model.vq_model.vq_model_name).to(accelerator.device) + vq_model.eval() + vq_model.requires_grad_(False) + + model = MMadaModelLM.from_pretrained(config.model.mmada.pretrained_model_path, torch_dtype=torch.bfloat16).to(accelerator.device) + unwrapped_model = accelerator.unwrap_model(model) + original_vocab_size = unwrapped_model.get_input_embeddings().weight.shape[0] + logger.info("="*50) + logger.info(f"Calling resize_vocab...") + logger.info(f"Vocab size BEFORE resizing: {original_vocab_size}") + + resize_vocab(unwrapped_model, config) + + resized_vocab_size = unwrapped_model.get_input_embeddings().weight.shape[0] + logger.info(f"Vocab size AFTER resizing: {resized_vocab_size}") + logger.info(f"Config 'new_vocab_size': {config.model.mmada.new_vocab_size}") + + if resized_vocab_size == config.model.mmada.new_vocab_size: + logger.info("āœ… Vocab resize successful!") + else: + logger.info("āŒ Vocab resize FAILED or did not match config!") + logger.info("="*50) + + mask_id = model.config.mask_token_id + + ################################## + # Optimizer and LR scheduler # + ################################# + optimizer_config = config.optimizer.params + + # no decay on bias and layernorm and embedding + no_decay = ["bias", "layer_norm.weight", "mlm_ln.weight", "embeddings.weight"] + optimizer_grouped_parameters = [ + { + "params": [p for n, p in model.named_parameters() if + p.requires_grad and not any(nd in n for nd in no_decay)], + "weight_decay": optimizer_config.weight_decay, + }, + { + "params": [p for n, p in model.named_parameters() if + p.requires_grad and any(nd in n for nd in no_decay)], + "weight_decay": 0.0, + }, + ] + + optimizer_type = config.optimizer.name + if optimizer_type == "adamw": + optimizer = AdamW( + optimizer_grouped_parameters, + lr=optimizer_config.learning_rate, + betas=(optimizer_config.beta1, optimizer_config.beta2), + weight_decay=optimizer_config.weight_decay, + eps=optimizer_config.epsilon, + ) + else: + raise ValueError(f"Optimizer {optimizer_type} not supported") + + # Create mask scheduler + if config.get("mask_schedule", None) is not None: + schedule = config.mask_schedule.schedule + args = config.mask_schedule.get("params", {}) + mask_schedule = get_mask_schedule(schedule, **args) + else: + mask_schedule = get_mask_schedule(config.training.get("mask_schedule", "cosine")) + + lr_warmup_steps_for_scheduler = config.lr_scheduler.params.warmup_steps + max_train_steps_for_scheduler = config.training.max_train_steps + + # lr_warmup_steps_for_scheduler = config.lr_scheduler.params.warmup_steps * accelerator.num_processes + # max_train_steps_for_scheduler = config.training.max_train_steps * accelerator.num_processes + + lr_scheduler = get_scheduler( + config.lr_scheduler.scheduler, + optimizer=optimizer, + num_warmup_steps=lr_warmup_steps_for_scheduler, + num_training_steps=max_train_steps_for_scheduler, + min_lr_scale=config.lr_scheduler.params.min_lr_scale + ) + + ################################## + # DATALOADER # + ################################# + logger.info("Creating dataloader and lr_scheduler for S2T task") + train_dataset = SpeechTextDataset(config.dataset.data.name, config.dataset.data.subset, config.dataset.data.split) + eval_dataset = train_dataset + + logger.info(f"Dataset Prepared.") + + sampler = DistributedSampler(train_dataset, num_replicas=accelerator.num_processes, rank=accelerator.process_index, shuffle=True) if accelerator.num_processes > 1 else None + + train_dataloader = DataLoader(train_dataset, batch_size=config.training.batch_size_s2t, shuffle=True, collate_fn=collate_fn, num_workers=config.dataset.params.num_workers) + eval_dataloader = DataLoader(eval_dataset, batch_size=config.training.batch_size_s2t, collate_fn=collate_fn, num_workers=config.dataset.params.num_workers) + + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / config.training.gradient_accumulation_steps) + num_train_epochs = math.ceil(config.training.max_train_steps / num_update_steps_per_epoch) + + ################################## + # MODEL RESUME # + ################################# + global_step = 0 + first_epoch = 0 + + # if config.experiment.resume_from_checkpoint: + # dirs = os.listdir(config.experiment.output_dir) + # logger.info(f"dirs: {dirs}") + # dirs = [d for d in dirs if d.startswith("checkpoint")] + # dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + # path = dirs[-1] if len(dirs) > 0 else None + # logger.info(f"path: {path}") + # if path is not None: + # path = os.path.join(config.experiment.output_dir, path) + # logger.info(f"Resuming from checkpoint: {path}") + # global_step = int(os.path.basename(path).split("-")[1]) + # first_epoch = global_step // num_update_steps_per_epoch + # if os.path.exists(f'{path}/unwrapped_model/pytorch_model.bin'): + # state_dict = torch.load(f'{path}/unwrapped_model/pytorch_model.bin', map_location="cpu") + # model.load_state_dict(state_dict, strict=True) + # del state_dict + # elif os.path.exists(f'{path}/unwrapped_model/pytorch_model.bin.index.json'): + # from transformers.modeling_utils import load_sharded_checkpoint + # load_sharded_checkpoint(model, f'{path}/unwrapped_model/') + # # if safetensors sharded checkpoint exists + # elif os.path.exists(f'{path}/unwrapped_model/model.safetensors.index.json'): + # from transformers.modeling_utils import load_sharded_checkpoint + # load_sharded_checkpoint( + # model, + # f'{path}/unwrapped_model/', + # ) + # else: + # raise FileNotFoundError(f"Checkpoint {path}/unwrapped_model/pytorch_model.bin not found") + if config.experiment.resume_from_checkpoint: + output_dir = Path(config.experiment.output_dir) + if os.path.exists(output_dir): + checkpoint_dirs = [d for d in os.listdir(output_dir) if d.startswith("checkpoint-")] + if checkpoint_dirs: + latest_checkpoint = sorted(checkpoint_dirs, key=lambda x: int(x.split("-")[1]))[-1] + resume_path = os.path.join(output_dir, latest_checkpoint) + + logger.info(f"Resuming from checkpoint: {resume_path}") + + accelerator.load_state(resume_path) + + global_step = int(latest_checkpoint.split("-")[1]) + first_epoch = global_step // num_update_steps_per_epoch + else: + logger.info("No checkpoint found to resume from, starting from scratch.") + else: + logger.info("Output directory does not exist, starting from scratch.") + else: + logger.info("Not resuming from checkpoint.") + + ################################## + # Prepare accelerator # + ################################# + logger.info("Preparing model, optimizer and dataloader") + model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + model, optimizer, train_dataloader, lr_scheduler + ) + + lr_scheduler = get_scheduler( + config.lr_scheduler.scheduler, + optimizer=optimizer, + num_training_steps=config.training.max_train_steps, + num_warmup_steps=config.lr_scheduler.params.warmup_steps, + min_lr_scale=config.lr_scheduler.params.min_lr_scale + ) + + vq_model.to(device=accelerator.device) + + ################################## + # Training # + ################################# + logger.info("***** Running training *****") + logger.info(f" Num train examples = {len(train_dataset)}") + logger.info(f" Num Epochs = {num_train_epochs}") + logger.info(f" Num training steps = {config.training.max_train_steps}") + logger.info(f" Instantaneous batch size per device = {total_batch_size_per_gpu}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {config.training.gradient_accumulation_steps}") + + @torch.no_grad() + def prepare_inputs_and_labels_for_s2t( + input_ids_mmu, prompt_masks, labels_mmu, eps=1e-3 + ): + b, l = input_ids_mmu.shape + t = torch.rand(b, device=input_ids_mmu.device) + p_mask = (1 - eps) * t + eps + p_mask = p_mask[:, None].repeat(1, l) + + masked_indices = torch.rand((b, l), device=input_ids_mmu.device) < p_mask + # 126336 is used for [MASK] token + noisy_batch = torch.where(masked_indices, mask_id, input_ids_mmu) + masked_indices = noisy_batch == mask_id + noisy_batch[prompt_masks.bool()] = input_ids_mmu[prompt_masks.bool()] + masked_indices = noisy_batch == mask_id + + prompt_masks = prompt_masks.to(torch.int64) + answer_lengths = torch.sum((1 - prompt_masks), dim=-1, keepdim=True) + answer_lengths = answer_lengths.repeat(1, noisy_batch.shape[1]) + + return noisy_batch, labels_mmu, p_mask, answer_lengths + + batch_time_m = AverageMeter() + data_time_m = AverageMeter() + end = time.time() + + + for epoch in range(first_epoch, num_train_epochs): + model.train() + for step, batch in enumerate(train_dataloader): + data_time_m.update(time.time() - end) + + batch_size_s2t = len(batch["audio_path"]) + audio_paths, texts_s2t = batch["audio_path"], batch["text"] + # audio_values_s2t = audio_values_s2t.to(accelerator.device, non_blocking=True) + # audio_tokens_s2t = vq_model.encode(audio_values_s2t) + offset = len(uni_prompting.text_tokenizer) + config.model.mmada.codebook_size + + all_audio_tokens = [] + for path in audio_paths: + tokens = vq_model.encode(path) + tokens_with_offset = tokens + offset + all_audio_tokens.append(tokens_with_offset) + + pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 50256 + max_len = max(t.shape[1] for t in all_audio_tokens) + + padded_batch = [] + for tokens in all_audio_tokens: + num_padding = max_len - tokens.shape[1] + padding = torch.full((tokens.shape[0], num_padding), pad_token_id, dtype=tokens.dtype, device=tokens.device) + padded_tensor = torch.cat([tokens, padding], dim=1) + padded_batch.append(padded_tensor) + + audio_tokens_s2t = torch.cat(padded_batch, dim=0) + audio_tokens_s2t = audio_tokens_s2t.to(accelerator.device, non_blocking=True) + input_ids_s2t, prompt_masks, labels_s2t = uni_prompting((audio_tokens_s2t, texts_s2t), 's2t') + + input_ids_s2t, labels, p_mask_s2t, answer_lengths = prepare_inputs_and_labels_for_s2t( + input_ids_s2t, prompt_masks, labels_s2t + ) + + input_ids_s2t = input_ids_s2t.to(accelerator.device, non_blocking=True) + logits, s2t_loss = accelerator.unwrap_model(model).forward_s2t(input_ids = input_ids_s2t, + p_mask_s2t = p_mask_s2t, + labels= labels, + answer_lengths= answer_lengths, + batch_size_s2t = batch_size_s2t) + + # Gather the losses across all processes for logging + avg_loss_s2t = accelerator.gather(s2t_loss.repeat(batch_size_s2t)).mean() + + accelerator.backward(s2t_loss) + + if config.training.max_grad_norm is not None and accelerator.sync_gradients: + accelerator.clip_grad_norm_(model.parameters(), config.training.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + + # Checks if the accelerator has performed an optimization step + if accelerator.sync_gradients: + batch_time_m.update(time.time() - end) + end = time.time() + + avg_masking_rate = accelerator.gather(p_mask_s2t.mean()).mean() + + # Log metrics + if (global_step + 1) % config.experiment.log_every == 0: + samples_per_second_per_gpu = ( + config.training.gradient_accumulation_steps * batch_size_s2t / batch_time_m.val + ) + logs = { + "step_loss_s2t": avg_loss_s2t.item(), + "lr": lr_scheduler.get_last_lr()[0], + "avg_masking_rate": avg_masking_rate.item(), + "samples/sec/gpu": samples_per_second_per_gpu, + "data_time": data_time_m.val, + "batch_time": batch_time_m.val, + } + accelerator.log(logs, step=global_step + 1) + + logger.info( + f"Step: {global_step + 1} " + f"Loss_s2t: {avg_loss_s2t.item():0.4f} " + f"Data (t): {data_time_m.val:0.4f}, {samples_per_second_per_gpu:0.2f}/s/gpu " + f"Batch (t): {batch_time_m.val:0.4f} " + f"LR: {lr_scheduler.get_last_lr()[0]:0.6f}" + ) + batch_time_m.reset() + data_time_m.reset() + + # Save model checkpoint + if (global_step + 1) % config.experiment.save_every == 0: + # save_checkpoint(model, config, accelerator, global_step + 1, uni_prompting) + save_checkpoint( + accelerator, + config.experiment.output_dir, + global_step + 1, + uni_prompting, + config.experiment.get("checkpoints_total_limit") + ) + + # if (global_step + 1) % config.experiment.eval_every == 0 : + # run_evaluation(model, eval_dataloader, accelerator, global_step + 1) + # if accelerator.is_main_process: + # torch.cuda.empty_cache() + + global_step += 1 + + if global_step >= config.training.max_train_steps: + break + if global_step >= config.training.max_train_steps: + break + + accelerator.wait_for_everyone() + save_checkpoint( + accelerator, + config.experiment.output_dir, + global_step + 1, + uni_prompting, + config.experiment.get("checkpoints_total_limit") + ) + if accelerator.is_main_process: + model = accelerator.unwrap_model(model) + model.save_pretrained(config.experiment.output_dir, safe_serialization=True) + + accelerator.end_training() + +def save_checkpoint(accelerator, output_dir, global_step, uni_prompting, checkpoints_total_limit=None): + + if accelerator.is_main_process and checkpoints_total_limit is not None: + checkpoints = os.listdir(output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint-")] + + if len(checkpoints) >= checkpoints_total_limit: + checkpoints_to_remove = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))[:len(checkpoints) - checkpoints_total_limit + 1] + logger.info(f"Removing {len(checkpoints_to_remove)} old checkpoints to keep a total of {checkpoints_total_limit}") + for ckpt in checkpoints_to_remove: + shutil.rmtree(os.path.join(output_dir, ckpt)) + + save_path = os.path.join(output_dir, f"checkpoint-{global_step}") + + accelerator.save_state(save_path) + logger.info(f"Saved complete state to {save_path}") + + if accelerator.is_main_process: + uni_prompting.text_tokenizer.save_pretrained(save_path) + +def log_grad_norm(model, accelerator, global_step): + if not accelerator.is_main_process: + return + for name, param in model.named_parameters(): + if param.grad is not None: + grads = param.grad.detach().data + grad_norm = (grads.norm(p=2) / grads.numel()).item() + accelerator.log({"grad_norm/" + name: grad_norm}, step=global_step) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/MMaDA/training/train_mmada_stage2.py b/MMaDA/training/train_mmada_stage2.py new file mode 100644 index 0000000000000000000000000000000000000000..b492b1ec8ea12375ec5e247b7b15deeca8ad8947 --- /dev/null +++ b/MMaDA/training/train_mmada_stage2.py @@ -0,0 +1,995 @@ +# Copyright 2025 MMaDA Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +os.environ["TOKENIZERS_PARALLELISM"] = "true" +import json +import logging +import math +import shutil +import time +from pathlib import Path +from typing import Union + +import numpy as np +from PIL import Image +from omegaconf import OmegaConf +import wandb +import torch +from torch.optim import AdamW +from lightning.pytorch.utilities import CombinedLoader + +from transformers import AutoTokenizer, AutoConfig +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import DistributedType, set_seed + +from training.data import Text2ImageDataset +from training.utils import get_config, flatten_omega_conf, image_transform +from training.imagenet_dataset import ImageNetDataset +from parquet import RefinedWebDataset + +from models import MAGVITv2, get_mask_schedule, MMadaModelLM, MMadaConfig +from training.prompting_utils import UniversalPrompting +from models.lr_schedulers import get_scheduler +from models.logging import set_verbosity_info, set_verbosity_error + +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler + + +SYSTEM_PROMPT_LEN = 28 + +from training.utils import get_config, flatten_omega_conf, mask_or_random_replace_tokens, AverageMeter + +try: + import apex + + is_apex_available = True +except ImportError: + is_apex_available = False + +logger = get_logger(__name__, log_level="INFO") + + +def get_vq_model_class(model_type): + if model_type == "magvitv2": + return MAGVITv2 + elif model_type == "vq16": + return VQ_16 + else: + raise ValueError(f"model_type {model_type} not supported.") + + +def main(): + ######################### + # SETUP Accelerator # + ######################### + config = get_config() + + # Enable TF32 on Ampere GPUs + if config.training.enable_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.benchmark = True + torch.backends.cudnn.deterministic = False + + config.experiment.logging_dir = str(Path(config.experiment.output_dir) / "logs") + accelerator = Accelerator( + gradient_accumulation_steps=config.training.gradient_accumulation_steps, + mixed_precision=config.training.mixed_precision, + log_with="wandb", + project_dir=config.experiment.logging_dir, + split_batches=True, + ) + + total_batch_size_per_gpu = (config.training.batch_size_t2i + + config.training.batch_size_lm + + config.training.batch_size_mmu) + total_batch_size = ( + (config.training.batch_size_t2i + config.training.batch_size_lm + config.training.batch_size_mmu) + * accelerator.num_processes * config.training.gradient_accumulation_steps + ) + + if accelerator.distributed_type == DistributedType.DEEPSPEED: + accelerator.state.deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = ( + total_batch_size_per_gpu + ) + + ##################################### + # SETUP LOGGING, SEED and CONFIG # + ##################################### + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + set_verbosity_info() + else: + set_verbosity_error() + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + resume_wandb_run = config.wandb.resume + run_id = config.wandb.get("run_id", None) + if run_id is None: + resume_wandb_run = False + run_id = wandb.util.generate_id() + config.wandb.run_id = run_id + + wandb_init_kwargs = dict( + name=config.experiment.name, + id=run_id, + resume=resume_wandb_run, + entity=config.wandb.get("entity", None), + config_exclude_keys=[], + ) + wandb_config = {k: v for k, v in flatten_omega_conf(config, resolve=True)} + wandb_config.pop("experiment.resume_from_checkpoint") + + accelerator.init_trackers( + config.experiment.project, + config=wandb_config, + init_kwargs={"wandb": wandb_init_kwargs}, + ) + + if accelerator.is_main_process: + os.makedirs(config.experiment.output_dir, exist_ok=True) + config_path = Path(config.experiment.output_dir) / "config.yaml" + logging.info(f"Saving config to {config_path}") + OmegaConf.save(config, config_path) + + # If passed along, set the training seed now. + if config.training.seed is not None: + set_seed(config.training.seed) + + ######################### + # MODELS and OPTIMIZER # + ######################### + logger.info("Loading models and optimizer") + + tokenizer = AutoTokenizer.from_pretrained(config.model.mmada.tokenizer_path, padding_side="left") + + uni_prompting = UniversalPrompting(tokenizer, max_text_len=config.dataset.preprocessing.max_seq_length, + special_tokens=( + "<|soi|>", "<|eoi|>", "<|sov|>", "<|eov|>", "<|t2i|>", + "<|mmu|>", "<|t2v|>", "<|v2v|>", "<|lvg|>" + ), + ignore_id=-100, cond_dropout_prob=config.training.cond_dropout_prob, use_reserved_token=True) + + print('special tokens : \n', uni_prompting.sptids_dict) + + # VQ model for processing image into discrete tokens + vq_model = get_vq_model_class(config.model.vq_model.type) + if config.model.vq_model.get("pretrained_model_path", None): + vq_model = vq_model().to(accelerator.device) + state_dict = torch.load(config.model.vq_model.pretrained_model_path)['model'] + vq_model.load_state_dict(state_dict) + else: + vq_model = vq_model.from_pretrained(config.model.vq_model.vq_model_name).to(accelerator.device) + vq_model.eval() + vq_model.requires_grad_(False) + + model = MMadaModelLM.from_pretrained(config.model.mmada.pretrained_model_path, torch_dtype=torch.bfloat16).to(accelerator.device) + + mask_id = model.config.mask_token_id + + ################################## + # Optimizer and LR scheduler # + ################################# + optimizer_config = config.optimizer.params + + # no decay on bias and layernorm and embedding + no_decay = ["bias", "layer_norm.weight", "mlm_ln.weight", "embeddings.weight"] + optimizer_grouped_parameters = [ + { + "params": [p for n, p in model.named_parameters() if + p.requires_grad and not any(nd in n for nd in no_decay)], + "weight_decay": optimizer_config.weight_decay, + }, + { + "params": [p for n, p in model.named_parameters() if + p.requires_grad and any(nd in n for nd in no_decay)], + "weight_decay": 0.0, + }, + ] + + optimizer_type = config.optimizer.name + if optimizer_type == "adamw": + optimizer = AdamW( + optimizer_grouped_parameters, + lr=optimizer_config.learning_rate, + betas=(optimizer_config.beta1, optimizer_config.beta2), + weight_decay=optimizer_config.weight_decay, + eps=optimizer_config.epsilon, + ) + else: + raise ValueError(f"Optimizer {optimizer_type} not supported") + + # Create mask scheduler + if config.get("mask_schedule", None) is not None: + schedule = config.mask_schedule.schedule + args = config.mask_schedule.get("params", {}) + mask_schedule = get_mask_schedule(schedule, **args) + else: + mask_schedule = get_mask_schedule(config.training.get("mask_schedule", "cosine")) + + lr_scheduler = get_scheduler( + config.lr_scheduler.scheduler, + optimizer=optimizer, + num_training_steps=config.training.max_train_steps, + num_warmup_steps=config.lr_scheduler.params.warmup_steps, + min_lr_scale=config.lr_scheduler.params.min_lr_scale + ) + + ################################## + # DATALOADER # + ################################# + logger.info("Creating dataloaders and lr_scheduler") + + total_batch_size_t2i_without_accum = config.training.batch_size_t2i * accelerator.num_processes + total_batch_size_t2i = ( + config.training.batch_size_t2i * accelerator.num_processes * config.training.gradient_accumulation_steps + ) + + # DataLoaders creation: + # We use webdataset for data loading. The dataloaders are created with sampling with replacement. + # We don't do dataset resuming here, instead we resample the shards and buffer each time. The sampling is stochastic. + # This means that the dataloading is not deterministic, but it's fast and efficient. + preproc_config = config.dataset.preprocessing + dataset_config = config.dataset.params + + # Data for generation + if config.dataset.gen_type == "t2i": + dataset = Text2ImageDataset( + train_shards_path_or_url=dataset_config.train_t2i_shards_path_or_url, + tokenizer=None, # we want to get raw texts + max_seq_length=preproc_config.max_seq_length, + num_train_examples=config.experiment.max_train_examples_t2i, + per_gpu_batch_size=config.training.batch_size_t2i, + global_batch_size=total_batch_size_t2i_without_accum, + num_workers=dataset_config.num_workers, + resolution=preproc_config.resolution, + shuffle_buffer_size=dataset_config.shuffle_buffer_size, + pin_memory=dataset_config.pin_memory, + persistent_workers=dataset_config.persistent_workers, + external_caption_path=dataset_config.external_caption_path, + external_journeydb_caption_path=dataset_config.external_journeydb_caption_path, + external_laion12m_caption_path=dataset_config.external_laion12m_caption_path, + external_cc12m_caption_path=dataset_config.external_cc12m_caption_path, + ) + train_dataloader_t2i = dataset.train_dataloader + num_update_steps_per_epoch = math.ceil( + train_dataloader_t2i.num_batches / config.training.gradient_accumulation_steps) + num_train_epochs = math.ceil(config.training.max_train_steps / num_update_steps_per_epoch) + + elif config.dataset.gen_type == "t2i_parquet": + # this part relies on the internal packages, which will not be released + num_update_steps_per_epoch = math.ceil(config.experiment.max_train_examples_t2i / total_batch_size_t2i) + num_train_epochs = math.ceil(config.training.max_train_steps / num_update_steps_per_epoch) + + train_dataloader_t2i = create_imagetext_dataloader( + train_shards_path_or_url=dataset_config.train_t2i_shards_path_or_url, + batch_size=config.training.batch_size_t2i, + image_size=preproc_config.resolution, + num_workers=dataset_config.num_workers, + num_readers=32, + predefined_steps=num_update_steps_per_epoch, + drop_last=True, + shuffle=True, + shuffle_buffer_size=dataset_config.shuffle_buffer_size + ) + + elif config.dataset.gen_type == "imagenet1k": + dataset_imagenet = ImageNetDataset( + dataset_config.train_t2i_shards_path_or_url, + image_size=preproc_config.resolution, + ) + + print('process index : ', + accelerator.process_index, ', ', accelerator.num_processes, + "Length: ", len(dataset_imagenet)) + + if accelerator.num_processes > 1: + sampler = DistributedSampler(dataset_imagenet, + num_replicas=accelerator.num_processes, + rank=accelerator.process_index, + shuffle=True, + ) + shuffle = False + else: + sampler = None + shuffle = True + + train_dataloader_t2i = DataLoader(dataset_imagenet, batch_size=config.training.batch_size_t2i, + sampler=sampler, collate_fn=dataset_imagenet.collate_fn, + shuffle=shuffle, num_workers=dataset_config.num_workers) + num_update_steps_per_epoch = math.ceil(len(dataset_imagenet) / total_batch_size_t2i) + num_train_epochs = math.ceil(config.training.max_train_steps / num_update_steps_per_epoch) + + else: + raise ValueError(f"Unsupported dataset type {config.dataset.type}") + + + total_batch_size_mmu_without_accum = config.training.batch_size_mmu * accelerator.num_processes + # Data for image captioning + if config.dataset.und_type == "captioning": + dataset_mmu = xr( + train_shards_path_or_url=dataset_config.train_mmu_shards_path_or_url, + tokenizer=None, # we want to get raw texts + max_seq_length=preproc_config.max_seq_length, + num_train_examples=config.experiment.max_train_examples_mmu, + per_gpu_batch_size=config.training.batch_size_mmu, + global_batch_size=total_batch_size_mmu_without_accum, + num_workers=dataset_config.num_workers, + resolution=preproc_config.resolution, + shuffle_buffer_size=dataset_config.shuffle_buffer_size, + pin_memory=dataset_config.pin_memory, + persistent_workers=dataset_config.persistent_workers, + external_caption_path=dataset_config.external_caption_path, + external_journeydb_caption_path=dataset_config.external_journeydb_caption_path, + external_laion12m_caption_path=dataset_config.external_laion12m_caption_path, + external_cc12m_caption_path=dataset_config.external_cc12m_caption_path, + is_captioning=True, + add_caption_prompt=dataset_config.add_caption_prompt, + ) + train_dataloader_mmu = dataset_mmu.train_dataloader + + elif config.dataset.und_type == "captioning_parquet": + train_dataloader_mmu = create_imagetext_dataloader( + train_shards_path_or_url=dataset_config.train_mmu_shards_path_or_url, + batch_size=config.training.batch_size_mmu, + image_size=preproc_config.resolution, + num_workers=dataset_config.num_workers, + num_readers=32, + predefined_steps=num_update_steps_per_epoch, + drop_last=True, + shuffle=True, + shuffle_buffer_size=dataset_config.shuffle_buffer_size, + is_captioning=True + ) + + + else: + raise NotImplementedError(f"Unsupported dataset type {config.dataset.und_type}") + + # LLM pure text dataset: RefinedWeb + dataset_lm = RefinedWebDataset(data_path=dataset_config.train_lm_shards_path_or_url, + rank=accelerator.process_index, + world_size=accelerator.num_processes, + num_workers=dataset_config.num_workers) + + train_dataloader_lm = torch.utils.data.DataLoader(dataset_lm, batch_size=config.training.batch_size_lm, + sampler=None, collate_fn=dataset_lm.collate_fn, + num_workers=dataset_config.num_workers) + + # Combine these dataloaders into a single iterable model + iterables = { + "t2i_flow": train_dataloader_t2i, + "lm_flow": train_dataloader_lm, + "mmu_flow": train_dataloader_mmu, + } + + combined_dataloader = CombinedLoader(iterables, mode=config.dataset.combined_loader_mode) + + ################################## + # MODEL RESUME # + ################################# + global_step = 0 + first_epoch = 0 + + if config.experiment.resume_from_checkpoint: + dirs = os.listdir(config.experiment.output_dir) + logger.info(f"dirs: {dirs}") + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + logger.info(f"path: {path}") + if path is not None: + path = os.path.join(config.experiment.output_dir, path) + logger.info(f"Resuming from checkpoint: {path}") + global_step = int(os.path.basename(path).split("-")[1]) + first_epoch = global_step // num_update_steps_per_epoch + if os.path.exists(f'{path}/unwrapped_model/pytorch_model.bin'): + state_dict = torch.load(f'{path}/unwrapped_model/pytorch_model.bin', map_location="cpu") + model.load_state_dict(state_dict, strict=True) + del state_dict + elif os.path.exists(f'{path}/unwrapped_model/pytorch_model.bin.index.json'): + from safetensors.torch import load_file + from transformers.modeling_utils import load_sharded_checkpoint + load_sharded_checkpoint(model, f'{path}/unwrapped_model/') + # if safetensors sharded checkpoint exists + elif os.path.exists(f'{path}/unwrapped_model/model.safetensors.index.json'): + from transformers.modeling_utils import load_sharded_checkpoint + load_sharded_checkpoint( + model, + f'{path}/unwrapped_model/', + # weight_map=None, + # load_state_dict_fn="safetensors" + ) + else: + raise FileNotFoundError(f"Checkpoint {path}/unwrapped_model/pytorch_model.bin not found") + else: + logger.info("Not resuming from checkpoint") + + ################################## + # Prepare accelerator # + ################################# + logger.info("Preparing model, optimizer and dataloaders") + model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler) + + vq_model.to(device=accelerator.device) + + mask_dtype = model.get_input_embeddings().weight.dtype + + ################################## + # Training # + ################################# + logger.info("***** Running training *****") + logger.info(f" Num training steps = {config.training.max_train_steps}") + logger.info(f" Instantaneous batch size per device = {total_batch_size_per_gpu}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {config.training.gradient_accumulation_steps}") + + @torch.no_grad() + def prepare_inputs_and_labels( + pixel_values_or_image_ids: Union[torch.FloatTensor, torch.LongTensor], + texts: Union[str, str], + min_masking_rate: float = 0.0, + is_train: bool = True, + ): + + image_tokens = vq_model.get_code(pixel_values_or_image_ids) + image_tokens = image_tokens + len(uni_prompting.text_tokenizer) + # create MLM mask and labels + input_ids, labels, loss_weight, mask_prob = mask_or_random_replace_tokens( + image_tokens, + mask_id, + config, + mask_schedule=mask_schedule, + is_train=is_train, + ) + input_ids, masks, labels = uni_prompting((texts, input_ids, labels), 't2i') + return input_ids, labels, mask_prob, image_tokens, masks + + @torch.no_grad() + def prepare_inputs_and_labels_for_text( + texts: Union[str, str], max_seq_len, eps=1e-3 + ): + # create MLM mask and labels + + input_ids_lm, prompt_mask, labels_lm = uni_prompting((texts_lm, max_seq_len), 'lm') + b, l = input_ids_lm.shape + t = torch.rand(b, device=input_ids_lm.device) + p_mask = (1 - eps) * t + eps + p_mask = p_mask[:, None].repeat(1, l) + + masked_indices = torch.rand((b, l), device=input_ids_lm.device) < p_mask + # 126336 is used for [MASK] token + noisy_batch = torch.where(masked_indices, mask_id, input_ids_lm) + masked_indices = noisy_batch == mask_id + + return noisy_batch, labels_lm, p_mask + + @torch.no_grad() + def prepare_inputs_and_labels_for_mmu( + input_ids_mmu, prompt_masks, labels_mmu, eps=1e-3 + ): + b, l = input_ids_mmu.shape + t = torch.rand(b, device=input_ids_mmu.device) + p_mask = (1 - eps) * t + eps + p_mask = p_mask[:, None].repeat(1, l) + + masked_indices = torch.rand((b, l), device=input_ids_mmu.device) < p_mask + # 126336 is used for [MASK] token + noisy_batch = torch.where(masked_indices, mask_id, input_ids_mmu) + masked_indices = noisy_batch == mask_id + noisy_batch[prompt_masks.bool()] = input_ids_mmu[prompt_masks.bool()] + masked_indices = noisy_batch == mask_id + + prompt_masks = prompt_masks.to(torch.int64) + answer_lengths = torch.sum((1 - prompt_masks), dim=-1, keepdim=True) + answer_lengths = answer_lengths.repeat(1, noisy_batch.shape[1]) + + return noisy_batch, labels_mmu, p_mask, answer_lengths + + + batch_time_m = AverageMeter() + data_time_m = AverageMeter() + end = time.time() + + for epoch in range(first_epoch, num_train_epochs): + model.train() + for batch, batch_idx, dataloader_idx in combined_dataloader: + # for loss calculation + batch_size_t2i = batch["t2i_flow"]["images"].shape[0] + batch_size_lm = len(batch["lm_flow"]["input_ids"]) + batch_size_mmu = batch["mmu_flow"]["images"].shape[0] + + # *-------*-------*-------*-------*-------*-------*-------*-------*-------*-------*-------* + # Build formatted sequences for class-conditional/text-to-image generation + # *-------*-------*-------*-------*-------*-------*-------*-------*-------*-------*-------* + pixel_values, texts = batch["t2i_flow"]["images"], batch["t2i_flow"]["input_ids"] + pixel_values = pixel_values.to(accelerator.device, non_blocking=True) + data_time_m.update(time.time() - end) + + # Encode images to image tokens, mask them and create input and labels + ( + input_ids, + labels, + mask_prob, + image_tokens_ori, + t2i_masks + ) = prepare_inputs_and_labels(pixel_values, texts, config.training.min_masking_rate) + + # *-------*-------*-------*-------*-------*-------*-------*-------*-------*-------*-------* + # Build formatted sequences for language modeling + # *-------*-------*-------*-------*-------*-------*-------*-------*-------*-------*-------* + max_seq_len = input_ids.shape[-1] + texts_lm = batch["lm_flow"]["input_ids"] + ( + input_ids_lm, + labels_lm, + p_mask_lm + ) = prepare_inputs_and_labels_for_text(texts_lm, max_seq_len) + input_ids = torch.cat((input_ids, input_ids_lm.to(input_ids.device)), dim=0) + labels = torch.cat((labels, labels_lm.to(input_ids.device)), dim=0) + + + + # *-------*-------*-------*-------*-------*-------*-------*-------*-------*-------*-------* + # Build formatted sequences for captioning/multimodal understanding + # *-------*-------*-------*-------*-------*-------*-------*-------*-------*-------*-------* + if "llava" in config.dataset.und_type: + pixel_values_mmu, input_ids_mmu, labels_mmu = (batch["mmu_flow"]["images"], batch["mmu_flow"]["input_ids"],batch["mmu_flow"]["labels"]) + pixel_values_mmu = pixel_values_mmu.to(accelerator.device, non_blocking=True) + input_ids_mmu = input_ids_mmu.to(accelerator.device, non_blocking=True) + image_tokens_mmu = vq_model.get_code(pixel_values_mmu) + image_tokens_mmu = image_tokens_mmu + len(uni_prompting.text_tokenizer) + + input_ids_mmu = torch.cat([ + (torch.ones(input_ids_mmu.shape[0], 1) * uni_prompting.sptids_dict['<|mmu|>']).to( + accelerator.device), + (torch.ones(input_ids_mmu.shape[0], 1) * uni_prompting.sptids_dict['<|soi|>']).to( + accelerator.device), + image_tokens_mmu, + (torch.ones(input_ids_mmu.shape[0], 1) * uni_prompting.sptids_dict['<|eoi|>']).to( + accelerator.device), + input_ids_mmu, + ], dim=1).long() + + labels_mmu = torch.cat([ + (torch.ones(input_ids_mmu.shape[0], 1) * uni_prompting.ignore_id).to(accelerator.device), + (torch.ones(input_ids_mmu.shape[0], 1) * uni_prompting.ignore_id).to(accelerator.device), + torch.ones_like(image_tokens_mmu) * uni_prompting.ignore_id, + (torch.ones(input_ids_mmu.shape[0], 1) * uni_prompting.ignore_id).to(accelerator.device), + labels_mmu.to(accelerator.device) + ], dim=1).long() + + else: + pixel_values_mmu, texts_mmu = batch["mmu_flow"]["images"], batch["mmu_flow"]["input_ids"] + pixel_values_mmu = pixel_values_mmu.to(accelerator.device, non_blocking=True) + image_tokens_mmu = vq_model.get_code(pixel_values_mmu) + image_tokens_mmu = image_tokens_mmu + len(uni_prompting.text_tokenizer) + + input_ids_mmu, prompt_masks, labels_mmu = uni_prompting((image_tokens_mmu, texts_mmu), 'mmu') + ( + input_ids_mmu, + labels_mmu, + p_mask_mmu, + answer_lengths + ) = prepare_inputs_and_labels_for_mmu(input_ids_mmu, prompt_masks, labels_mmu) + input_ids_mmu = input_ids_mmu.to(accelerator.device, non_blocking=True) + + input_ids = torch.cat((input_ids, input_ids_mmu.to(input_ids.device)), dim=0) + labels = torch.cat((labels, labels_mmu.to(input_ids.device)), dim=0) + + if global_step == 0 and epoch == 0: + logger.info("Input ids: {}".format(input_ids)) + logger.info("Labels: {}".format(labels)) + + with accelerator.accumulate(model): + logits, loss_t2i, loss_lm, loss_mmu = model.forward_process( + input_ids=input_ids, + labels=labels, + batch_size_t2i=batch_size_t2i, + batch_size_lm=batch_size_lm, + batch_size_mmu=batch_size_mmu, + max_seq_length=config.dataset.preprocessing.max_seq_length, + p_mask_lm=p_mask_lm, + p_mask_mmu=p_mask_mmu, + answer_lengths=answer_lengths, + t2i_masks=t2i_masks + ) + # Gather the losses across all processes for logging (if we use distributed training). + avg_loss_t2i = accelerator.gather(loss_t2i.repeat(config.training.batch_size_t2i)).mean() + avg_loss_lm = accelerator.gather(loss_lm.repeat(config.training.batch_size_lm)).mean() + avg_loss_mmu = accelerator.gather(loss_mmu.repeat(config.training.batch_size_mmu)).mean() + loss = config.training.t2i_coeff * loss_t2i + \ + config.training.lm_coeff * loss_lm + \ + config.training.mmu_coeff * loss_mmu + + avg_masking_rate = accelerator.gather(mask_prob.repeat(config.training.batch_size_t2i)).mean() + + accelerator.backward(loss) + + if config.training.max_grad_norm is not None and accelerator.sync_gradients: + accelerator.clip_grad_norm_(model.parameters(), config.training.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + + # log gradient norm before zeroing it + if ( + accelerator.sync_gradients + and (global_step + 1) % config.experiment.log_grad_norm_every == 0 + and accelerator.is_main_process + ): + log_grad_norm(model, accelerator, global_step + 1) + + optimizer.zero_grad(set_to_none=True) + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + + batch_time_m.update(time.time() - end) + end = time.time() + + # Log metrics + if (global_step + 1) % config.experiment.log_every == 0: + samples_per_second_per_gpu = ( + config.training.gradient_accumulation_steps * total_batch_size_per_gpu / batch_time_m.val + ) + logs = { + "step_loss_t2i": avg_loss_t2i.item(), + "step_loss_mmu": avg_loss_mmu.item(), + "step_loss_lm": avg_loss_lm.item(), + "lr": lr_scheduler.get_last_lr()[0], + "avg_masking_rate": avg_masking_rate.item(), + "samples/sec/gpu": samples_per_second_per_gpu, + "data_time": data_time_m.val, + "batch_time": batch_time_m.val, + } + accelerator.log(logs, step=global_step + 1) + + logger.info( + f"Step: {global_step + 1} " + f"Loss_t2i: {avg_loss_t2i.item():0.4f} " + f"Loss_mmu: {avg_loss_mmu.item():0.4f} " + f"Loss_lm: {avg_loss_lm.item():0.4f} " + f"Data (t): {data_time_m.val:0.4f}, {samples_per_second_per_gpu:0.2f}/s/gpu " + f"Batch (t): {batch_time_m.val:0.4f} " + f"LR: {lr_scheduler.get_last_lr()[0]:0.6f}" + ) + + # resetting batch / data time meters per log window + batch_time_m.reset() + data_time_m.reset() + + # Save model checkpoint + if (global_step + 1) % config.experiment.save_every == 0: + save_checkpoint(model, config, accelerator, global_step + 1, uni_prompting) + + if ((global_step + 1) % config.experiment.generate_every == 0 or global_step == 0) and accelerator.is_main_process: + generate_images( + model, + vq_model, + uni_prompting, + accelerator, + config, + global_step + 1, + mask_schedule=mask_schedule, + force_no_cfg=False + ) + + generate_images( + model, + vq_model, + uni_prompting, + accelerator, + config, + global_step + 1, + mask_schedule=mask_schedule, + force_no_cfg=True + ) + + visualize_predictions( + model, + vq_model, + uni_prompting, + config, + global_step + 1, + input_ids, + image_tokens_ori, + batch["t2i_flow"]["images"], + texts, + logits, + accelerator + ) + + understanding_images( + model, + vq_model, + uni_prompting, + accelerator, + config, + global_step + 1, + ) + + global_step += 1 + + # Stop training if max steps is reached + if global_step >= config.training.max_train_steps: + break + # End for + + accelerator.wait_for_everyone() + + # Evaluate and save checkpoint at the end of training + save_checkpoint(model, config, accelerator, global_step, uni_prompting) + + # Save the final trained checkpoint + if accelerator.is_main_process: + model = accelerator.unwrap_model(model) + model.save_pretrained(config.experiment.output_dir, safe_serialization=True) + + accelerator.end_training() + + +@torch.no_grad() +def visualize_predictions( + model, + vq_model, + uni_prompting, + config, + global_step, + input_ids, + image_tokens_ori, + ori_images, + texts, + logits, + accelerator +): + logger.info("Visualizing predictions...") + model.eval() + + recons_images = vq_model.decode_code(image_tokens_ori - len(uni_prompting.text_tokenizer)) + recons_images = torch.clamp((recons_images + 1.0) / 2.0, min=0.0, max=1.0) + recons_images *= 255.0 + recons_images = recons_images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8) + + images = torch.clamp((ori_images + 1.0) / 2.0, min=0.0, max=1.0) + images *= 255.0 + images = images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8) + predictions = logits[:config.training.batch_size_t2i, -(config.model.mmada.num_vq_tokens + 1):-1:, len(uni_prompting.text_tokenizer) + config.model.mmada.num_new_special_tokens: len(uni_prompting.text_tokenizer) + config.model.mmada.num_new_special_tokens + config.model.mmada.codebook_size] + predictions = predictions.argmax(axis=-1) + # mask_token_id = config.model.mmada.vocab_size - 1 - len(uni_prompting.text_tokenizer) + mask_token_id = accelerator.unwrap_model(model).config.mask_token_id - len(uni_prompting.text_tokenizer) + input_ids = input_ids[:config.training.batch_size_t2i, -(config.model.mmada.num_vq_tokens + 1):-1:] - len(uni_prompting.text_tokenizer) + mask_ratio = list((torch.where(input_ids == mask_token_id, 1, 0).sum( + dim=-1) / config.model.mmada.num_vq_tokens).cpu().numpy()) + predicted_images = torch.where(input_ids == mask_token_id, predictions, input_ids) + predicted_images = vq_model.decode_code(predicted_images) + predicted_images = torch.clamp((predicted_images + 1.0) / 2.0, min=0.0, max=1.0) + predicted_images *= 255.0 + predicted_images = predicted_images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8) + predicted_images = np.concatenate((images, recons_images, predicted_images), 2) + pil_images = [Image.fromarray(image) for image in predicted_images] + + # Log images + wandb_images = [wandb.Image(image, caption=f'mask ratio: {r:0.2f} \n caption: {texts[i]}') for i, (image, r) in + enumerate(zip(pil_images, mask_ratio))] + wandb.log({"Original images v.s. Reconstructed images v.s. Predicted images": wandb_images}, step=global_step) + + model.train() + + +@torch.no_grad() +def generate_images( + model, + vq_model, + uni_prompting, + accelerator, + config, + global_step, + mask_schedule, + force_no_cfg = False +): + logger.info("Generating images...") + model.eval() + + # read validation prompts from file + with open(config.dataset.params.validation_prompts_file, "r") as f: + validation_prompts = f.read().splitlines() + + mask_dtype = model.get_input_embeddings().weight.dtype + mask_token_id = accelerator.unwrap_model(model).config.mask_token_id + image_tokens = torch.ones((len(validation_prompts), config.model.mmada.num_vq_tokens), dtype=torch.long, + device=accelerator.device) * mask_token_id + input_ids, attention_mask = uni_prompting((validation_prompts, image_tokens), 't2i_gen') + if not force_no_cfg and config.training.guidance_scale > 0: + uncond_input_ids, uncond_attention_mask = uni_prompting(([''] * len(validation_prompts), image_tokens), 't2i_gen') + cfg_scale = config.training.guidance_scale + else: + uncond_input_ids = None + uncond_attention_mask = None + cfg_scale = 0 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + else: + weight_dtype = torch.float32 + + with torch.autocast("cuda", dtype=weight_dtype, enabled=accelerator.mixed_precision != "no"): + # Generate images + gen_token_ids = accelerator.unwrap_model(model).t2i_generate( + input_ids=input_ids, + uncond_input_ids=uncond_input_ids, + attention_mask=attention_mask, + uncond_attention_mask=uncond_attention_mask, + guidance_scale=cfg_scale, + temperature=config.training.get("generation_temperature", 1.0), + timesteps=config.training.generation_timesteps, + noise_schedule=mask_schedule, + noise_type=config.training.get("noise_type", "mask"), + predict_all_tokens=config.training.get("predict_all_tokens", False), + seq_len=config.model.mmada.num_vq_tokens, + uni_prompting=uni_prompting, + config=config, + ) + # In the beginning of training, the model is not fully trained and the generated token ids can be out of range + # so we clamp them to the correct range. + gen_token_ids = torch.clamp(gen_token_ids, max=accelerator.unwrap_model(model).config.codebook_size - 1, min=0) + images = vq_model.decode_code(gen_token_ids) + + model.train() + + if config.training.get("pre_encode", False): + del vq_model + + # Convert to PIL images + images = torch.clamp((images + 1.0) / 2.0, min=0.0, max=1.0) + images *= 255.0 + images = images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8) + pil_images = [Image.fromarray(image) for image in images] + + # Log images + wandb_images = [wandb.Image(image, caption=validation_prompts[i]) for i, image in enumerate(pil_images)] + wandb.log({f"Generated images with cfg {cfg_scale}": wandb_images}, step=global_step) + + + +@torch.no_grad() +def understanding_images( + model, + vq_model, + uni_prompting, + accelerator, + config, + global_step, +): + logger.info("Understanding images...") + model.eval() + + file_list = os.listdir(config.dataset.params.mmu_image_root) + file_list = [f for f in file_list if f.lower().endswith(('.jpg', '.png', '.jpeg'))] + responses = ['' for i in range(len(file_list))] + images = [] + + device = accelerator.device + + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + else: + weight_dtype = torch.float32 + + for i, file_name in enumerate(file_list): + image_path = os.path.join(config.dataset.params.mmu_image_root, file_name) + image_ori = Image.open(image_path).convert("RGB") + image = image_transform(image_ori, resolution=config.dataset.params.resolution).to(device) + image = image.unsqueeze(0) + images.append(image) + image_tokens = vq_model.get_code(image) + len(uni_prompting.text_tokenizer) + batch_size = 1 + + input_ids = uni_prompting.text_tokenizer(['<|start_header_id|>user<|end_header_id|>\n' + "Please describe this image in detail." +'<|start_header_id|>assistant<|end_header_id|>\n'])['input_ids'] + input_ids = torch.tensor(input_ids).to(device) + + input_ids = torch.cat([ + (torch.ones(input_ids.shape[0], 1) * uni_prompting.sptids_dict['<|mmu|>']).to(device), + (torch.ones(input_ids.shape[0], 1) * uni_prompting.sptids_dict['<|soi|>']).to(device), + image_tokens, + (torch.ones(input_ids.shape[0], 1) * uni_prompting.sptids_dict['<|eoi|>']).to(device), + (torch.ones(input_ids.shape[0], 1) * uni_prompting.sptids_dict['<|sot|>']).to(device), + input_ids + ], dim=1).long() + with torch.autocast("cuda", dtype=weight_dtype, enabled=accelerator.mixed_precision != "no"): + output_ids = accelerator.unwrap_model(model).mmu_generate(input_ids) + # output_ids = torch.stack(output_ids).squeeze()[None] + + text = uni_prompting.text_tokenizer.batch_decode(output_ids[:, input_ids.shape[1]:], skip_special_tokens=True) + responses[i] += text[0] + model.train() + images = torch.cat(images, dim=0) + images = torch.clamp((images + 1.0) / 2.0, min=0.0, max=1.0) + images *= 255.0 + images = images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8) + pil_images = [Image.fromarray(image) for image in images] + + # Log images + wandb_images = [wandb.Image(image, caption=responses[i]) for i, image in enumerate(pil_images)] + wandb.log({"Understanding images": wandb_images}, step=global_step) + + +def save_checkpoint(model, config, accelerator, global_step, uni_prompting): + output_dir = config.experiment.output_dir + checkpoints_total_limit = config.experiment.get("checkpoints_total_limit", None) + + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if accelerator.is_main_process and checkpoints_total_limit is not None: + checkpoints = os.listdir(output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= checkpoints_total_limit: + num_to_remove = len(checkpoints) - checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = Path(output_dir) / f"checkpoint-{global_step}" + + # retrieve the model on all processes for deepspeed stage 3 to work then save on one process (we are not using stage 3 yet) + # XXX: could also make this conditional on deepspeed + state_dict = accelerator.get_state_dict(model) + if accelerator.is_main_process: + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.save_pretrained( + save_path / "unwrapped_model", + save_function=accelerator.save, + state_dict=state_dict, + safe_serialization=True + ) + json.dump({"global_step": global_step}, (save_path / "metadata.json").open("w+")) + logger.info(f"Saved state to {save_path}") + + # save tokenizer + uni_prompting.text_tokenizer.save_pretrained(save_path/ "unwrapped_model") + + +def log_grad_norm(model, accelerator, global_step): + for name, param in model.named_parameters(): + if param.grad is not None: + grads = param.grad.detach().data + grad_norm = (grads.norm(p=2) / grads.numel()).item() + accelerator.log({"grad_norm/" + name: grad_norm}, step=global_step) + + +if __name__ == "__main__": + main() diff --git a/MMaDA/training/train_mmada_stage3.py b/MMaDA/training/train_mmada_stage3.py new file mode 100644 index 0000000000000000000000000000000000000000..674a10e1484b426715cb7898bd6c4ae228843c35 --- /dev/null +++ b/MMaDA/training/train_mmada_stage3.py @@ -0,0 +1,1105 @@ +# Copyright 2025 MMaDA Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +os.environ["TOKENIZERS_PARALLELISM"] = "true" +import json +import pandas +import logging +import math +import shutil +import time +from pathlib import Path +from typing import Union + +import numpy as np +from PIL import Image +from omegaconf import OmegaConf +import wandb +import torch +from torch.optim import AdamW +from lightning.pytorch.utilities import CombinedLoader + +from transformers import AutoTokenizer, AutoConfig +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import DistributedType, set_seed + +from training.data import Text2ImageDataset +from training.utils import get_config, flatten_omega_conf, image_transform +from training.imagenet_dataset import ImageNetDataset +from parquet import RefinedWebDataset, ChatDataset + +from models import MAGVITv2, get_mask_schedule, MMadaModelLM, MMadaConfig +from training.prompting_utils import UniversalPrompting + +from models.lr_schedulers import get_scheduler +from models.logging import set_verbosity_info, set_verbosity_error + +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler + +from training.utils import get_config, flatten_omega_conf, mask_or_random_replace_tokens, AverageMeter + +try: + import apex + + is_apex_available = True +except ImportError: + is_apex_available = False + +logger = get_logger(__name__, log_level="INFO") + + +def get_vq_model_class(model_type): + if model_type == "magvitv2": + return MAGVITv2 + elif model_type == "vq16": + return VQ_16 + else: + raise ValueError(f"model_type {model_type} not supported.") + + +def main(): + ######################### + # SETUP Accelerator # + ######################### + config = get_config() + + # Enable TF32 on Ampere GPUs + if config.training.enable_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.benchmark = True + torch.backends.cudnn.deterministic = False + + config.experiment.logging_dir = str(Path(config.experiment.output_dir) / "logs") + accelerator = Accelerator( + gradient_accumulation_steps=config.training.gradient_accumulation_steps, + mixed_precision=config.training.mixed_precision, + log_with="wandb", + project_dir=config.experiment.logging_dir, + split_batches=True, + ) + + total_batch_size_per_gpu = (config.training.batch_size_t2i + + config.training.batch_size_lm + + config.training.batch_size_mmu) + total_batch_size = ( + (config.training.batch_size_t2i + config.training.batch_size_lm + config.training.batch_size_mmu) + * accelerator.num_processes * config.training.gradient_accumulation_steps + ) + + if accelerator.distributed_type == DistributedType.DEEPSPEED: + accelerator.state.deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = ( + total_batch_size_per_gpu + ) + + ##################################### + # SETUP LOGGING, SEED and CONFIG # + ##################################### + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + set_verbosity_info() + else: + set_verbosity_error() + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + resume_wandb_run = config.wandb.resume + run_id = config.wandb.get("run_id", None) + if run_id is None: + resume_wandb_run = False + run_id = wandb.util.generate_id() + config.wandb.run_id = run_id + + wandb_init_kwargs = dict( + name=config.experiment.name, + id=run_id, + resume=resume_wandb_run, + entity=config.wandb.get("entity", None), + config_exclude_keys=[], + ) + wandb_config = {k: v for k, v in flatten_omega_conf(config, resolve=True)} + wandb_config.pop("experiment.resume_from_checkpoint") + + accelerator.init_trackers( + config.experiment.project, + config=wandb_config, + init_kwargs={"wandb": wandb_init_kwargs}, + ) + + if accelerator.is_main_process: + os.makedirs(config.experiment.output_dir, exist_ok=True) + config_path = Path(config.experiment.output_dir) / "config.yaml" + logging.info(f"Saving config to {config_path}") + OmegaConf.save(config, config_path) + + # If passed along, set the training seed now. + if config.training.seed is not None: + set_seed(config.training.seed) + + ######################### + # MODELS and OPTIMIZER # + ######################### + logger.info("Loading models and optimizer") + + tokenizer = AutoTokenizer.from_pretrained(config.model.mmada.tokenizer_path, padding_side="left") + + uni_prompting = UniversalPrompting(tokenizer, max_text_len=config.dataset.preprocessing.max_seq_length, + special_tokens=( + "<|soi|>", "<|eoi|>", "<|sov|>", "<|eov|>", "<|t2i|>", + "<|mmu|>", "<|t2v|>", "<|v2v|>", "<|lvg|>" + ), + ignore_id=-100, cond_dropout_prob=config.training.cond_dropout_prob, use_reserved_token=True) + + print('special tokens : \n', uni_prompting.sptids_dict) + + # VQ model for processing image into discrete tokens + vq_model = get_vq_model_class(config.model.vq_model.type) + if config.model.vq_model.get("pretrained_model_path", None): + vq_model = vq_model().to(accelerator.device) + state_dict = torch.load(config.model.vq_model.pretrained_model_path)['model'] + vq_model.load_state_dict(state_dict) + else: + vq_model = vq_model.from_pretrained(config.model.vq_model.vq_model_name).to(accelerator.device) + vq_model.eval() + vq_model.requires_grad_(False) + + model = MMadaModelLM.from_pretrained(config.model.mmada.pretrained_model_path, torch_dtype=torch.bfloat16).to(accelerator.device) + + mask_id = model.config.mask_token_id + + ################################## + # Optimizer and LR scheduler # + ################################# + optimizer_config = config.optimizer.params + + # no decay on bias and layernorm and embedding + no_decay = ["bias", "layer_norm.weight", "mlm_ln.weight", "embeddings.weight"] + optimizer_grouped_parameters = [ + { + "params": [p for n, p in model.named_parameters() if + p.requires_grad and not any(nd in n for nd in no_decay)], + "weight_decay": optimizer_config.weight_decay, + }, + { + "params": [p for n, p in model.named_parameters() if + p.requires_grad and any(nd in n for nd in no_decay)], + "weight_decay": 0.0, + }, + ] + + optimizer_type = config.optimizer.name + if optimizer_type == "adamw": + optimizer = AdamW( + optimizer_grouped_parameters, + lr=optimizer_config.learning_rate, + betas=(optimizer_config.beta1, optimizer_config.beta2), + weight_decay=optimizer_config.weight_decay, + eps=optimizer_config.epsilon, + ) + else: + raise ValueError(f"Optimizer {optimizer_type} not supported") + + # Create mask scheduler + if config.get("mask_schedule", None) is not None: + schedule = config.mask_schedule.schedule + args = config.mask_schedule.get("params", {}) + mask_schedule = get_mask_schedule(schedule, **args) + else: + mask_schedule = get_mask_schedule(config.training.get("mask_schedule", "cosine")) + + lr_scheduler = get_scheduler( + config.lr_scheduler.scheduler, + optimizer=optimizer, + num_training_steps=config.training.max_train_steps, + num_warmup_steps=config.lr_scheduler.params.warmup_steps, + min_lr_scale=config.lr_scheduler.params.min_lr_scale + ) + + ################################## + # DATALOADER # + ################################# + logger.info("Creating dataloaders and lr_scheduler") + + total_batch_size_t2i_without_accum = config.training.batch_size_t2i * accelerator.num_processes + total_batch_size_t2i = ( + config.training.batch_size_t2i * accelerator.num_processes * config.training.gradient_accumulation_steps + ) + + # DataLoaders creation: + # We use webdataset for data loading. The dataloaders are created with sampling with replacement. + # We don't do dataset resuming here, instead we resample the shards and buffer each time. The sampling is stochastic. + # This means that the dataloading is not deterministic, but it's fast and efficient. + preproc_config = config.dataset.preprocessing + dataset_config = config.dataset.params + + # Data for generation + if config.dataset.gen_type == "t2i": + dataset = Text2ImageDataset( + train_shards_path_or_url=dataset_config.train_t2i_shards_path_or_url, + tokenizer=uni_prompting.text_tokenizer, # we want to get raw texts, tokenizer is just for length counting + max_seq_length=preproc_config.max_seq_length, + num_train_examples=config.experiment.max_train_examples_t2i, + per_gpu_batch_size=config.training.batch_size_t2i, + global_batch_size=total_batch_size_t2i_without_accum, + num_workers=dataset_config.num_workers, + resolution=preproc_config.resolution, + shuffle_buffer_size=dataset_config.shuffle_buffer_size, + pin_memory=dataset_config.pin_memory, + persistent_workers=dataset_config.persistent_workers, + external_caption_path=dataset_config.external_caption_path, + external_journeydb_caption_path=dataset_config.external_journeydb_caption_path, + external_laion12m_caption_path=dataset_config.external_laion12m_caption_path, + external_cc12m_caption_path=dataset_config.external_cc12m_caption_path, + external_text_to_image_2M_512_caption_path=dataset_config.external_text_to_image_2M_512_caption_path, + ) + train_dataloader_t2i = dataset.train_dataloader + num_update_steps_per_epoch = math.ceil( + train_dataloader_t2i.num_batches / config.training.gradient_accumulation_steps) + num_train_epochs = math.ceil(config.training.max_train_steps / num_update_steps_per_epoch) + + elif config.dataset.gen_type == "t2i_parquet": + # this part relies on the internal packages, which will not be released + num_update_steps_per_epoch = math.ceil(config.experiment.max_train_examples_t2i / total_batch_size_t2i) + num_train_epochs = math.ceil(config.training.max_train_steps / num_update_steps_per_epoch) + + train_dataloader_t2i = create_imagetext_dataloader( + train_shards_path_or_url=dataset_config.train_t2i_shards_path_or_url, + batch_size=config.training.batch_size_t2i, + image_size=preproc_config.resolution, + num_workers=dataset_config.num_workers, + num_readers=32, + predefined_steps=num_update_steps_per_epoch, + drop_last=True, + shuffle=True, + shuffle_buffer_size=dataset_config.shuffle_buffer_size + ) + + elif config.dataset.gen_type == "imagenet1k": + dataset_imagenet = ImageNetDataset( + dataset_config.train_t2i_shards_path_or_url, + image_size=preproc_config.resolution, + ) + + print('process index : ', + accelerator.process_index, ', ', accelerator.num_processes, + "Length: ", len(dataset_imagenet)) + + if accelerator.num_processes > 1: + sampler = DistributedSampler(dataset_imagenet, + num_replicas=accelerator.num_processes, + rank=accelerator.process_index, + shuffle=True, + ) + shuffle = False + else: + sampler = None + shuffle = True + + train_dataloader_t2i = DataLoader(dataset_imagenet, batch_size=config.training.batch_size_t2i, + sampler=sampler, collate_fn=dataset_imagenet.collate_fn, + shuffle=shuffle, num_workers=dataset_config.num_workers) + num_update_steps_per_epoch = math.ceil(len(dataset_imagenet) / total_batch_size_t2i) + num_train_epochs = math.ceil(config.training.max_train_steps / num_update_steps_per_epoch) + + else: + raise ValueError(f"Unsupported dataset type {config.dataset.type}") + + total_batch_size_mmu_without_accum = config.training.batch_size_mmu * accelerator.num_processes + # Data for image captioning + if config.dataset.und_type == "captioning": + dataset_mmu = Text2ImageDataset( + train_shards_path_or_url=dataset_config.train_mmu_shards_path_or_url, + tokenizer=uni_prompting.text_tokenizer, # we want to get raw texts + max_seq_length=preproc_config.max_seq_length, + num_train_examples=config.experiment.max_train_examples_mmu, + per_gpu_batch_size=config.training.batch_size_mmu, + global_batch_size=total_batch_size_mmu_without_accum, + num_workers=dataset_config.num_workers, + resolution=preproc_config.resolution, + shuffle_buffer_size=dataset_config.shuffle_buffer_size, + pin_memory=dataset_config.pin_memory, + persistent_workers=dataset_config.persistent_workers, + external_caption_path=dataset_config.external_caption_path, + external_journeydb_caption_path=dataset_config.external_journeydb_caption_path, + external_laion12m_caption_path=dataset_config.external_laion12m_caption_path, + external_cc12m_caption_path=dataset_config.external_cc12m_caption_path, + external_text_to_image_2M_512_caption_path=dataset_config.external_text_to_image_2M_512_caption_path, + is_captioning=True, + add_caption_prompt=dataset_config.add_caption_prompt, + ) + train_dataloader_mmu = dataset_mmu.train_dataloader + + elif config.dataset.und_type == "captioning_parquet": + train_dataloader_mmu = create_imagetext_dataloader( + train_shards_path_or_url=dataset_config.train_mmu_shards_path_or_url, + batch_size=config.training.batch_size_mmu, + image_size=preproc_config.resolution, + num_workers=dataset_config.num_workers, + num_readers=32, + predefined_steps=num_update_steps_per_epoch, + drop_last=True, + shuffle=True, + shuffle_buffer_size=dataset_config.shuffle_buffer_size, + is_captioning=True + ) + else: + raise NotImplementedError(f"Unsupported dataset type {config.dataset.und_type}") + + dataset_lm = ChatDataset(data_path=dataset_config.train_lm_shards_path_or_url, + rank=accelerator.process_index, + world_size=accelerator.num_processes, + num_workers=dataset_config.num_workers, + max_length=preproc_config.max_seq_length, + tokenizer=uni_prompting.text_tokenizer, + ) + + train_dataloader_lm = torch.utils.data.DataLoader(dataset_lm, batch_size=config.training.batch_size_lm, + sampler=None, collate_fn=dataset_lm.collate_fn, + num_workers=dataset_config.num_workers) + + # Combine these dataloaders into a single iterable model + iterables = { + "t2i_flow": train_dataloader_t2i, + "lm_flow": train_dataloader_lm, + "mmu_flow": train_dataloader_mmu, + } + + combined_dataloader = CombinedLoader(iterables, mode=config.dataset.combined_loader_mode) + + ################################## + # MODEL RESUME # + ################################# + global_step = 0 + first_epoch = 0 + start_step = 0 + + if config.experiment.resume_from_checkpoint: + dirs = os.listdir(config.experiment.output_dir) + logger.info(f"dirs: {dirs}") + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + logger.info(f"path: {path}") + if path is not None: + path = os.path.join(config.experiment.output_dir, path) + logger.info(f"Resuming from checkpoint: {path}") + global_step = start_step = int(os.path.basename(path).split("-")[1]) + first_epoch = global_step // num_update_steps_per_epoch + if os.path.exists(f'{path}/unwrapped_model/pytorch_model.bin'): + state_dict = torch.load(f'{path}/unwrapped_model/pytorch_model.bin', map_location="cpu") + model.load_state_dict(state_dict, strict=True) + del state_dict + elif os.path.exists(f'{path}/unwrapped_model/pytorch_model.bin.index.json'): + from safetensors.torch import load_file + from transformers.modeling_utils import load_sharded_checkpoint + load_sharded_checkpoint(model, f'{path}/unwrapped_model/') + # if safetensors sharded checkpoint exists + elif os.path.exists(f'{path}/unwrapped_model/model.safetensors.index.json'): + from transformers.modeling_utils import load_sharded_checkpoint + load_sharded_checkpoint( + model, + f'{path}/unwrapped_model/', + # weight_map=None, + # load_state_dict_fn="safetensors" + ) + else: + raise FileNotFoundError(f"Checkpoint {path}/unwrapped_model/pytorch_model.bin not found") + else: + logger.info("Not resuming from checkpoint") + + ################################## + # Prepare accelerator # + ################################# + logger.info("Preparing model, optimizer and dataloaders") + model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler) + + vq_model.to(device=accelerator.device) + + mask_dtype = model.get_input_embeddings().weight.dtype + + ################################## + # Training # + ################################# + logger.info("***** Running training *****") + logger.info(f" Num training steps = {config.training.max_train_steps}") + logger.info(f" Instantaneous batch size per device = {total_batch_size_per_gpu}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {config.training.gradient_accumulation_steps}") + + @torch.no_grad() + def prepare_inputs_and_labels( + pixel_values_or_image_ids: Union[torch.FloatTensor, torch.LongTensor], + texts: Union[str, list[str]], + min_masking_rate: float = 0.0, + is_train: bool = True, + seed: int = None + ): + + image_tokens = vq_model.get_code(pixel_values_or_image_ids) + image_tokens = image_tokens + len(uni_prompting.text_tokenizer) + # create MLM mask and labels + input_ids, labels, loss_weight, mask_prob = mask_or_random_replace_tokens( + image_tokens, + mask_id, + config, + mask_schedule=mask_schedule, + is_train=is_train, + seed=seed + ) + input_ids, masks, labels = uni_prompting((texts, input_ids, labels), 't2i') + return input_ids, labels, mask_prob, image_tokens, masks + + @torch.no_grad() + def prepare_inputs_and_labels_for_text( + texts: Union[str, list[str]], max_seq_len, eps=1e-3 + ): + # create MLM mask and labels + + input_ids_lm, prompt_mask, labels_lm = uni_prompting((texts, max_seq_len), 'lm') + b, l = input_ids_lm.shape + t = torch.rand(b, device=input_ids_lm.device) + p_mask = (1 - eps) * t + eps + p_mask = p_mask[:, None].repeat(1, l) + + masked_indices = torch.rand((b, l), device=input_ids_lm.device) < p_mask + # 126336 is used for [MASK] token + noisy_batch = torch.where(masked_indices, mask_id, input_ids_lm) + masked_indices = noisy_batch == mask_id + + return noisy_batch, labels_lm, p_mask + + @torch.no_grad() + def prepare_inputs_and_labels_for_chat_text( + texts: Union[str, list[str]], max_seq_len, eps=1e-3 + ): + # create MLM mask and labels + + input_ids_lm, prompt_mask, labels_lm = uni_prompting((texts, max_seq_len), 'lm_chat') + b, l = input_ids_lm.shape + t = torch.rand(b, device=input_ids_lm.device) + p_mask = (1 - eps) * t + eps + p_mask = p_mask[:, None].repeat(1, l) + + masked_indices = torch.rand((b, l), device=input_ids_lm.device) < p_mask + # 126336 is used for [MASK] token + noisy_batch = torch.where(masked_indices, mask_id, input_ids_lm) + masked_indices = noisy_batch == mask_id + noisy_batch[prompt_mask.bool()] = input_ids_lm[prompt_mask.bool()] + masked_indices = noisy_batch == mask_id + answer_lengths_lm = torch.sum((1 - prompt_mask), dim=-1, keepdim=True) + answer_lengths_lm = answer_lengths_lm.repeat(1, noisy_batch.shape[1]) + + return noisy_batch, labels_lm, p_mask, answer_lengths_lm + + @torch.no_grad() + def prepare_inputs_and_labels_for_mmu( + input_ids_mmu, prompt_masks, labels_mmu, eps=1e-3 + ): + b, l = input_ids_mmu.shape + t = torch.rand(b, device=input_ids_mmu.device) + p_mask = (1 - eps) * t + eps + p_mask = p_mask[:, None].repeat(1, l) + + masked_indices = torch.rand((b, l), device=input_ids_mmu.device) < p_mask + # 126336 is used for [MASK] token + noisy_batch = torch.where(masked_indices, mask_id, input_ids_mmu) + masked_indices = noisy_batch == mask_id + noisy_batch[prompt_masks.bool()] = input_ids_mmu[prompt_masks.bool()] + masked_indices = noisy_batch == mask_id + + prompt_masks = prompt_masks.to(torch.int64) + answer_lengths = torch.sum((1 - prompt_masks), dim=-1, keepdim=True) + answer_lengths = answer_lengths.repeat(1, noisy_batch.shape[1]) + + return noisy_batch, labels_mmu, p_mask, answer_lengths + + + + batch_time_m = AverageMeter() + data_time_m = AverageMeter() + end = time.time() + + for epoch in range(first_epoch, num_train_epochs): + model.train() + for batch, batch_idx, dataloader_idx in combined_dataloader: + # for loss calculation + batch_size_t2i = batch["t2i_flow"]["images"].shape[0] + batch_size_lm = len(batch["lm_flow"]["input_ids"]) + batch_size_mmu = batch["mmu_flow"]["images"].shape[0] + + # *-------*-------*-------*-------*-------*-------*-------*-------*-------*-------*-------* + # Build formatted sequences for class-conditional/text-to-image generation + # *-------*-------*-------*-------*-------*-------*-------*-------*-------*-------*-------* + pixel_values, texts = batch["t2i_flow"]["images"], batch["t2i_flow"]["input_ids"] + pixel_values = pixel_values.to(accelerator.device, non_blocking=True) + data_time_m.update(time.time() - end) + + # Encode images to image tokens, mask them and create input and labels + ( + input_ids, + labels, + mask_prob, + image_tokens_ori, + t2i_masks + ) = prepare_inputs_and_labels(pixel_values, texts, config.training.min_masking_rate) + + # *-------*-------*-------*-------*-------*-------*-------*-------*-------*-------*-------* + # Build formatted sequences for language modeling + # *-------*-------*-------*-------*-------*-------*-------*-------*-------*-------*-------* + max_seq_len = input_ids.shape[-1] + texts_lm = batch["lm_flow"]["input_ids"] + ( + input_ids_lm, + labels_lm, + p_mask_lm, + answer_lengths_lm + ) = prepare_inputs_and_labels_for_chat_text(texts_lm, max_seq_len) + input_ids = torch.cat((input_ids, input_ids_lm.to(input_ids.device)), dim=0) + labels = torch.cat((labels, labels_lm.to(input_ids.device)), dim=0) + + + + # *-------*-------*-------*-------*-------*-------*-------*-------*-------*-------*-------* + # Build formatted sequences for captioning/multimodal understanding + # *-------*-------*-------*-------*-------*-------*-------*-------*-------*-------*-------* + if "llava" in config.dataset.und_type: + pixel_values_mmu, input_ids_mmu, labels_mmu = (batch["mmu_flow"]["images"], batch["mmu_flow"]["input_ids"],batch["mmu_flow"]["labels"]) + pixel_values_mmu = pixel_values_mmu.to(accelerator.device, non_blocking=True) + input_ids_mmu = input_ids_mmu.to(accelerator.device, non_blocking=True) + image_tokens_mmu = vq_model.get_code(pixel_values_mmu) + image_tokens_mmu = image_tokens_mmu + len(uni_prompting.text_tokenizer) + + input_ids_mmu = torch.cat([ + (torch.ones(input_ids_mmu.shape[0], 1) * uni_prompting.sptids_dict['<|mmu|>']).to( + accelerator.device), + (torch.ones(input_ids_mmu.shape[0], 1) * uni_prompting.sptids_dict['<|soi|>']).to( + accelerator.device), + image_tokens_mmu, + (torch.ones(input_ids_mmu.shape[0], 1) * uni_prompting.sptids_dict['<|eoi|>']).to( + accelerator.device), + input_ids_mmu, + ], dim=1).long() + + labels_mmu = torch.cat([ + (torch.ones(input_ids_mmu.shape[0], 1) * uni_prompting.ignore_id).to(accelerator.device), + (torch.ones(input_ids_mmu.shape[0], 1) * uni_prompting.ignore_id).to(accelerator.device), + torch.ones_like(image_tokens_mmu) * uni_prompting.ignore_id, + (torch.ones(input_ids_mmu.shape[0], 1) * uni_prompting.ignore_id).to(accelerator.device), + labels_mmu.to(accelerator.device) + ], dim=1).long() + + else: + pixel_values_mmu, texts_mmu = batch["mmu_flow"]["images"], batch["mmu_flow"]["input_ids"] + pixel_values_mmu = pixel_values_mmu.to(accelerator.device, non_blocking=True) + image_tokens_mmu = vq_model.get_code(pixel_values_mmu) + image_tokens_mmu = image_tokens_mmu + len(uni_prompting.text_tokenizer) + + input_ids_mmu, prompt_masks, labels_mmu = uni_prompting((image_tokens_mmu, texts_mmu), 'mmu') + ( + input_ids_mmu, + labels_mmu, + p_mask_mmu, + answer_lengths + ) = prepare_inputs_and_labels_for_mmu(input_ids_mmu, prompt_masks, labels_mmu) + input_ids_mmu = input_ids_mmu.to(accelerator.device, non_blocking=True) + + input_ids = torch.cat((input_ids, input_ids_mmu.to(input_ids.device)), dim=0) + labels = torch.cat((labels, labels_mmu.to(input_ids.device)), dim=0) + + if global_step == 0 and epoch == 0: + logger.info("Input ids: {}".format(input_ids)) + logger.info("Labels: {}".format(labels)) + + with accelerator.accumulate(model): + logits, loss_t2i, loss_lm, loss_mmu = model.forward_process( + input_ids=input_ids, + labels=labels, + batch_size_t2i=batch_size_t2i, + batch_size_lm=batch_size_lm, + batch_size_mmu=batch_size_mmu, + max_seq_length=config.dataset.preprocessing.max_seq_length, + p_mask_lm=p_mask_lm, + p_mask_mmu=p_mask_mmu, + answer_lengths=answer_lengths, + t2i_masks=t2i_masks, + answer_lengths_lm=answer_lengths_lm + ) + # Gather the losses across all processes for logging (if we use distributed training). + avg_loss_t2i = accelerator.gather(loss_t2i.repeat(config.training.batch_size_t2i)).mean() + avg_loss_lm = accelerator.gather(loss_lm.repeat(config.training.batch_size_lm)).mean() + avg_loss_mmu = accelerator.gather(loss_mmu.repeat(config.training.batch_size_mmu)).mean() + loss = config.training.t2i_coeff * loss_t2i + \ + config.training.lm_coeff * loss_lm + \ + config.training.mmu_coeff * loss_mmu + + avg_masking_rate = accelerator.gather(mask_prob.repeat(config.training.batch_size_t2i)).mean() + + accelerator.backward(loss) + + if config.training.max_grad_norm is not None and accelerator.sync_gradients: + accelerator.clip_grad_norm_(model.parameters(), config.training.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + + # log gradient norm before zeroing it + if ( + accelerator.sync_gradients + and (global_step + 1) % config.experiment.log_grad_norm_every == 0 + and accelerator.is_main_process + ): + log_grad_norm(model, accelerator, global_step + 1) + + optimizer.zero_grad(set_to_none=True) + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + + batch_time_m.update(time.time() - end) + end = time.time() + + # Log metrics + if (global_step + 1) % config.experiment.log_every == 0: + samples_per_second_per_gpu = ( + config.training.gradient_accumulation_steps * total_batch_size_per_gpu / batch_time_m.val + ) + logs = { + "step_loss_t2i": avg_loss_t2i.item(), + "step_loss_mmu": avg_loss_mmu.item(), + "step_loss_lm": avg_loss_lm.item(), + "lr": lr_scheduler.get_last_lr()[0], + "avg_masking_rate": avg_masking_rate.item(), + "samples/sec/gpu": samples_per_second_per_gpu, + "data_time": data_time_m.val, + "batch_time": batch_time_m.val, + } + accelerator.log(logs, step=global_step + 1) + + logger.info( + f"Step: {global_step + 1} " + f"Loss_t2i: {avg_loss_t2i.item():0.4f} " + f"Loss_mmu: {avg_loss_mmu.item():0.4f} " + f"Loss_lm: {avg_loss_lm.item():0.4f} " + f"Data (t): {data_time_m.val:0.4f}, {samples_per_second_per_gpu:0.2f}/s/gpu " + f"Batch (t): {batch_time_m.val:0.4f} " + f"LR: {lr_scheduler.get_last_lr()[0]:0.6f}" + ) + + # resetting batch / data time meters per log window + batch_time_m.reset() + data_time_m.reset() + # Save model checkpoint + if (global_step + 1) % config.experiment.save_every == 0: + save_checkpoint(model, config, accelerator, global_step + 1, uni_prompting) + + if ((global_step + 1) % config.experiment.generate_every == 0 or global_step == start_step) and accelerator.is_main_process: + generate_images( + model, + vq_model, + uni_prompting, + accelerator, + config, + global_step + 1, + mask_schedule=mask_schedule, + force_no_cfg=False + ) + + generate_images( + model, + vq_model, + uni_prompting, + accelerator, + config, + global_step + 1, + mask_schedule=mask_schedule, + force_no_cfg=True + ) + + visualize_predictions( + model, + vq_model, + uni_prompting, + config, + global_step + 1, + input_ids, + image_tokens_ori, + batch["t2i_flow"]["images"], + texts, + logits, + accelerator + ) + + understanding_images( + model, + vq_model, + uni_prompting, + accelerator, + config, + global_step + 1, + ) + + generate_chat_text( + model, + uni_prompting, + accelerator, + config, + global_step + 1, + ) + + global_step += 1 + # Stop training if max steps is reached + if global_step >= config.training.max_train_steps: + break + # End for + + accelerator.wait_for_everyone() + + # Evaluate and save checkpoint at the end of training + save_checkpoint(model, config, accelerator, global_step, uni_prompting) + + # Save the final trained checkpoint + if accelerator.is_main_process: + model = accelerator.unwrap_model(model) + model.save_pretrained(config.experiment.output_dir, safe_serialization=True) + + accelerator.end_training() + + +@torch.no_grad() +def visualize_predictions( + model, + vq_model, + uni_prompting, + config, + global_step, + input_ids, + image_tokens_ori, + ori_images, + texts, + logits, + accelerator +): + logger.info("Visualizing predictions...") + model.eval() + + recons_images = vq_model.decode_code(image_tokens_ori - len(uni_prompting.text_tokenizer)) + recons_images = torch.clamp((recons_images + 1.0) / 2.0, min=0.0, max=1.0) + recons_images *= 255.0 + recons_images = recons_images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8) + + images = torch.clamp((ori_images + 1.0) / 2.0, min=0.0, max=1.0) + images *= 255.0 + images = images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8) + predictions = logits[:config.training.batch_size_t2i, -(config.model.mmada.num_vq_tokens + 1):-1:, len(uni_prompting.text_tokenizer) + config.model.mmada.num_new_special_tokens: len(uni_prompting.text_tokenizer) + config.model.mmada.num_new_special_tokens + config.model.mmada.codebook_size] + predictions = predictions.argmax(axis=-1) + # mask_token_id = config.model.mmada.vocab_size - 1 - len(uni_prompting.text_tokenizer) + mask_token_id = accelerator.unwrap_model(model).config.mask_token_id - len(uni_prompting.text_tokenizer) + input_ids = input_ids[:config.training.batch_size_t2i, -(config.model.mmada.num_vq_tokens + 1):-1:] - len(uni_prompting.text_tokenizer) + mask_ratio = list((torch.where(input_ids == mask_token_id, 1, 0).sum( + dim=-1) / config.model.mmada.num_vq_tokens).cpu().numpy()) + predicted_images = torch.where(input_ids == mask_token_id, predictions, input_ids) + predicted_images = vq_model.decode_code(predicted_images) + predicted_images = torch.clamp((predicted_images + 1.0) / 2.0, min=0.0, max=1.0) + predicted_images *= 255.0 + predicted_images = predicted_images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8) + predicted_images = np.concatenate((images, recons_images, predicted_images), 2) + pil_images = [Image.fromarray(image) for image in predicted_images] + + # Log images + wandb_images = [wandb.Image(image, caption=f'mask ratio: {r:0.2f} \n caption: {texts[i]}') for i, (image, r) in + enumerate(zip(pil_images, mask_ratio))] + wandb.log({"Original images v.s. Reconstructed images v.s. Predicted images": wandb_images}, step=global_step) + + model.train() + + +@torch.no_grad() +def generate_images( + model, + vq_model, + uni_prompting, + accelerator, + config, + global_step, + mask_schedule, + force_no_cfg = False +): + logger.info("Generating images...") + model.eval() + + # read validation prompts from file + with open(config.dataset.params.validation_prompts_file, "r") as f: + validation_prompts = f.read().splitlines() + + mask_dtype = model.get_input_embeddings().weight.dtype + mask_token_id = accelerator.unwrap_model(model).config.mask_token_id + image_tokens = torch.ones((len(validation_prompts), config.model.mmada.num_vq_tokens), dtype=torch.long, + device=accelerator.device) * mask_token_id + input_ids, attention_mask = uni_prompting((validation_prompts, image_tokens), 't2i_gen') + if not force_no_cfg and config.training.guidance_scale > 0: + uncond_input_ids, uncond_attention_mask = uni_prompting(([''] * len(validation_prompts), image_tokens), 't2i_gen') + cfg_scale = config.training.guidance_scale + else: + uncond_input_ids = None + uncond_attention_mask = None + cfg_scale = 0 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + else: + weight_dtype = torch.float32 + + with torch.autocast("cuda", dtype=weight_dtype, enabled=accelerator.mixed_precision != "no"): + # Generate images + gen_token_ids = accelerator.unwrap_model(model).t2i_generate( + input_ids=input_ids, + uncond_input_ids=uncond_input_ids, + attention_mask=attention_mask, + uncond_attention_mask=uncond_attention_mask, + guidance_scale=cfg_scale, + temperature=config.training.get("generation_temperature", 1.0), + timesteps=config.training.generation_timesteps, + noise_schedule=mask_schedule, + noise_type=config.training.get("noise_type", "mask"), + predict_all_tokens=config.training.get("predict_all_tokens", False), + seq_len=config.model.mmada.num_vq_tokens, + uni_prompting=uni_prompting, + config=config, + ) + # In the beginning of training, the model is not fully trained and the generated token ids can be out of range + # so we clamp them to the correct range. + gen_token_ids = torch.clamp(gen_token_ids, max=accelerator.unwrap_model(model).config.codebook_size - 1, min=0) + images = vq_model.decode_code(gen_token_ids) + + model.train() + + if config.training.get("pre_encode", False): + del vq_model + + # Convert to PIL images + images = torch.clamp((images + 1.0) / 2.0, min=0.0, max=1.0) + images *= 255.0 + images = images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8) + pil_images = [Image.fromarray(image) for image in images] + + # Log images + wandb_images = [wandb.Image(image, caption=validation_prompts[i]) for i, image in enumerate(pil_images)] + wandb.log({f"Generated images with cfg {cfg_scale}": wandb_images}, step=global_step) + + + +@torch.no_grad() +def understanding_images( + model, + vq_model, + uni_prompting, + accelerator, + config, + global_step, +): + logger.info("Understanding images...") + model.eval() + + file_list = os.listdir(config.dataset.params.mmu_image_root) + file_list = [f for f in file_list if f.lower().endswith(('.jpg', '.png', '.jpeg'))] + responses = ['' for i in range(len(file_list))] + images = [] + + device = accelerator.device + + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + else: + weight_dtype = torch.float32 + + for i, file_name in enumerate(file_list): + image_path = os.path.join(config.dataset.params.mmu_image_root, file_name) + image_ori = Image.open(image_path).convert("RGB") + image = image_transform(image_ori, resolution=config.dataset.params.resolution).to(device) + image = image.unsqueeze(0) + images.append(image) + image_tokens = vq_model.get_code(image) + len(uni_prompting.text_tokenizer) + batch_size = 1 + + input_ids = uni_prompting.text_tokenizer(['<|start_header_id|>user<|end_header_id|>\n' + "Please describe this image in detail." +'<|start_header_id|>assistant<|end_header_id|>\n'])['input_ids'] + input_ids = torch.tensor(input_ids).to(device) + + input_ids = torch.cat([ + (torch.ones(input_ids.shape[0], 1) * uni_prompting.sptids_dict['<|mmu|>']).to(device), + (torch.ones(input_ids.shape[0], 1) * uni_prompting.sptids_dict['<|soi|>']).to(device), + image_tokens, + (torch.ones(input_ids.shape[0], 1) * uni_prompting.sptids_dict['<|eoi|>']).to(device), + (torch.ones(input_ids.shape[0], 1) * uni_prompting.sptids_dict['<|sot|>']).to(device), + input_ids + ], dim=1).long() + with torch.autocast("cuda", dtype=weight_dtype, enabled=accelerator.mixed_precision != "no"): + output_ids = accelerator.unwrap_model(model).mmu_generate(input_ids, max_new_tokens=config.dataset.preprocessing.max_seq_length, steps=config.dataset.preprocessing.max_seq_length // 2, block_length=config.dataset.preprocessing.max_seq_length // 4) + # output_ids = torch.stack(output_ids).squeeze()[None] + + text = uni_prompting.text_tokenizer.batch_decode(output_ids[:, input_ids.shape[1]:], skip_special_tokens=True) + responses[i] += text[0] + model.train() + images = torch.cat(images, dim=0) + images = torch.clamp((images + 1.0) / 2.0, min=0.0, max=1.0) + images *= 255.0 + images = images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8) + pil_images = [Image.fromarray(image) for image in images] + + # Log images + wandb_images = [wandb.Image(image, caption=responses[i]) for i, image in enumerate(pil_images)] + wandb.log({"Understanding images": wandb_images}, step=global_step) + +@torch.no_grad() +def generate_chat_text( + model, + uni_prompting, + accelerator, + config, + global_step, +): + logger.info("Generating chat text...") + model.eval() + + # čÆ»å–ę•°ę®ļ¼ŒčŽ·å– prompt åˆ—č”Ø + df = pandas.read_json(config.dataset.params.lm_chat_validation_jsonl, lines=True) + prompts = df['question'].tolist() + responses = [''] * len(prompts) + + device = accelerator.device + + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + else: + weight_dtype = torch.float32 + + # ē“Æē§Æę‰€ęœ‰ prompt/response åÆ¹ēš„ HTML å†…å®¹ļ¼Œä½œäøŗäø€äøŖę•“ä½“ log 到 wandb + html_content = "
    " + html_content += f"

    Step {global_step}

    " + + for i, prompt in enumerate(prompts): + # 原始 prompt ē”ØäŗŽå±•ē¤ŗ + original_prompt = prompt + + # ęž„é€ ē”Ÿęˆč¾“å…„ + prompt_with_tags = "<|start_header_id|>user<|end_header_id|>\n" + f"{prompt}" + "<|start_header_id|>assistant<|end_header_id|>\n" + input_ids = uni_prompting.text_tokenizer([prompt_with_tags])['input_ids'] + input_ids = torch.tensor(input_ids).to(device) + + with torch.autocast("cuda", dtype=weight_dtype, enabled=accelerator.mixed_precision != "no"): + output_ids = accelerator.unwrap_model(model).mmu_generate( + input_ids, + max_new_tokens=config.dataset.preprocessing.max_seq_length, + steps=config.dataset.preprocessing.max_seq_length // 2, + block_length=config.dataset.preprocessing.max_seq_length // 4 + ) + text = uni_prompting.text_tokenizer.batch_decode(output_ids[:, input_ids.shape[1]:], skip_special_tokens=True) + responses[i] += text[0] + + # å°†ęÆäø€ē»„ prompt 和 response ēš„å±•ē¤ŗäæ”ęÆę·»åŠ åˆ° HTML äø­ + html_content += f""" +
    +

    Prompt

    +

    {original_prompt}

    +

    Response

    +

    {responses[i]}

    +
    + """ + + html_content += "
    " # ē»“ęŸę•“ä½“å®¹å™Ø + + model.train() + + # åœØäø€äøŖ step å†…ē»Ÿäø€ log ē”Ÿęˆēš„å•äøŖ HTML åÆ¹č±”ļ¼ˆčæ™ę ·å°±äøä¼šå¤šę¬” log åŒäø€äøŖ step) + wandb.log({"chat_text_generation": wandb.Html(html_content)}, step=global_step) + + + # # ę‰“å°ę‰€ęœ‰é—®ē­”åÆ¹ + # logger.info("\n===== chat generated =====") + # for i, (prompt, response) in enumerate(zip(prompts, responses)): + # logger.info(f"\nQuestion {i+1}:{prompt}") + # logger.info(f"\nAnswer {i+1}:{response}") + # logger.info("-" * 50) + +def save_checkpoint(model, config, accelerator, global_step, uni_prompting): + output_dir = config.experiment.output_dir + checkpoints_total_limit = config.experiment.get("checkpoints_total_limit", None) + + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if accelerator.is_main_process and checkpoints_total_limit is not None: + checkpoints = os.listdir(output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= checkpoints_total_limit: + num_to_remove = len(checkpoints) - checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = Path(output_dir) / f"checkpoint-{global_step}" + + # retrieve the model on all processes for deepspeed stage 3 to work then save on one process (we are not using stage 3 yet) + # XXX: could also make this conditional on deepspeed + state_dict = accelerator.get_state_dict(model) + if accelerator.is_main_process: + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.save_pretrained( + save_path / "unwrapped_model", + save_function=accelerator.save, + state_dict=state_dict, + safe_serialization=True + ) + json.dump({"global_step": global_step}, (save_path / "metadata.json").open("w+")) + logger.info(f"Saved state to {save_path}") + + # save tokenizer + uni_prompting.text_tokenizer.save_pretrained(save_path/ "unwrapped_model") + + +def log_grad_norm(model, accelerator, global_step): + for name, param in model.named_parameters(): + if param.grad is not None: + grads = param.grad.detach().data + grad_norm = (grads.norm(p=2) / grads.numel()).item() + accelerator.log({"grad_norm/" + name: grad_norm}, step=global_step) + + + + + + +if __name__ == "__main__": + main() diff --git a/MMaDA/training/train_mmada_stage4.py b/MMaDA/training/train_mmada_stage4.py new file mode 100644 index 0000000000000000000000000000000000000000..03691b48075c6d67cc1a65c45e7d9b6a742c6904 --- /dev/null +++ b/MMaDA/training/train_mmada_stage4.py @@ -0,0 +1,1333 @@ + +# Copyright 2025 MMaDA Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +os.environ["TOKENIZERS_PARALLELISM"] = "true" +import json +import pandas +import logging +import math +import shutil +import time +import html +import random +from pathlib import Path +from typing import Union + +import numpy as np +from PIL import Image +from omegaconf import OmegaConf +import wandb +import torch +from torch.optim import AdamW +from lightning.pytorch.utilities import CombinedLoader + +from transformers import AutoTokenizer, AutoConfig +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import DistributedType, set_seed + +from training.data import Text2ImageDataset +from training.utils import get_config, flatten_omega_conf, image_transform, image_transform_squash +from training.imagenet_dataset import ImageNetDataset +from parquet import RefinedWebDataset, ChatDataset, VQADataset + +from models import MAGVITv2, get_mask_schedule, MMadaModelLM, MMadaConfig +from training.prompting_utils import UniversalPrompting +from models.lr_schedulers import get_scheduler +from models.logging import set_verbosity_info, set_verbosity_error + +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler +from training.utils import get_config, flatten_omega_conf, mask_or_random_replace_tokens, AverageMeter +from torchmetrics.functional.multimodal import clip_score +from functools import partial +import ImageReward as RM +try: + import apex + + is_apex_available = True +except ImportError: + is_apex_available = False + +logger = get_logger(__name__, log_level="INFO") + + +def get_vq_model_class(model_type): + if model_type == "magvitv2": + return MAGVITv2 + elif model_type == "vq16": + return VQ_16 + else: + raise ValueError(f"model_type {model_type} not supported.") + + +def main(): + ######################### + # SETUP Accelerator # + ######################### + config = get_config() + + # Enable TF32 on Ampere GPUs + if config.training.enable_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.benchmark = True + torch.backends.cudnn.deterministic = False + + config.experiment.logging_dir = str(Path(config.experiment.output_dir) / "logs") + accelerator = Accelerator( + gradient_accumulation_steps=config.training.gradient_accumulation_steps, + mixed_precision=config.training.mixed_precision, + log_with="wandb", + project_dir=config.experiment.logging_dir, + split_batches=True, + ) + + total_batch_size_per_gpu = (config.training.batch_size_t2i + + config.training.batch_size_lm + + config.training.batch_size_mmu) + total_batch_size = ( + (config.training.batch_size_t2i + config.training.batch_size_lm + config.training.batch_size_mmu) + * accelerator.num_processes * config.training.gradient_accumulation_steps + ) + + if accelerator.distributed_type == DistributedType.DEEPSPEED: + accelerator.state.deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = ( + total_batch_size_per_gpu + ) + + ##################################### + # SETUP LOGGING, SEED and CONFIG # + ##################################### + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + set_verbosity_info() + else: + set_verbosity_error() + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + resume_wandb_run = config.wandb.resume + run_id = config.wandb.get("run_id", None) + if run_id is None: + resume_wandb_run = False + run_id = wandb.util.generate_id() + config.wandb.run_id = run_id + + wandb_init_kwargs = dict( + name=config.experiment.name, + id=run_id, + resume=resume_wandb_run, + entity=config.wandb.get("entity", None), + config_exclude_keys=[], + ) + wandb_config = {k: v for k, v in flatten_omega_conf(config, resolve=True)} + wandb_config.pop("experiment.resume_from_checkpoint") + + accelerator.init_trackers( + config.experiment.project, + config=wandb_config, + init_kwargs={"wandb": wandb_init_kwargs}, + ) + + if accelerator.is_main_process: + os.makedirs(config.experiment.output_dir, exist_ok=True) + config_path = Path(config.experiment.output_dir) / "config.yaml" + logging.info(f"Saving config to {config_path}") + OmegaConf.save(config, config_path) + + # If passed along, set the training seed now. + if config.training.seed is not None: + set_seed(config.training.seed) + + ######################### + # MODELS and OPTIMIZER # + ######################### + logger.info("Loading models and optimizer") + + tokenizer = AutoTokenizer.from_pretrained(config.model.mmada.tokenizer_path, padding_side="left") + + uni_prompting = UniversalPrompting(tokenizer, max_text_len=config.dataset.preprocessing.max_seq_length, + special_tokens=( + "<|soi|>", "<|eoi|>", "<|sov|>", "<|eov|>", "<|t2i|>", + "<|mmu|>", "<|t2v|>", "<|v2v|>", "<|lvg|>" + ), + ignore_id=-100, cond_dropout_prob=config.training.cond_dropout_prob, use_reserved_token=True) + + print('special tokens : \n', uni_prompting.sptids_dict) + + # VQ model for processing image into discrete tokens + vq_model = get_vq_model_class(config.model.vq_model.type) + if config.model.vq_model.get("pretrained_model_path", None): + vq_model = vq_model().to(accelerator.device) + state_dict = torch.load(config.model.vq_model.pretrained_model_path)['model'] + vq_model.load_state_dict(state_dict) + else: + vq_model = vq_model.from_pretrained(config.model.vq_model.vq_model_name).to(accelerator.device) + vq_model.eval() + vq_model.requires_grad_(False) + + model = MMadaModelLM.from_pretrained(config.model.mmada.pretrained_model_path, torch_dtype=torch.bfloat16).to(accelerator.device) + + mask_id = model.config.mask_token_id + + ################################## + # Optimizer and LR scheduler # + ################################# + optimizer_config = config.optimizer.params + + # no decay on bias and layernorm and embedding + no_decay = ["bias", "layer_norm.weight", "mlm_ln.weight", "embeddings.weight"] + optimizer_grouped_parameters = [ + { + "params": [p for n, p in model.named_parameters() if + p.requires_grad and not any(nd in n for nd in no_decay)], + "weight_decay": optimizer_config.weight_decay, + }, + { + "params": [p for n, p in model.named_parameters() if + p.requires_grad and any(nd in n for nd in no_decay)], + "weight_decay": 0.0, + }, + ] + + optimizer_type = config.optimizer.name + if optimizer_type == "adamw": + optimizer = AdamW( + optimizer_grouped_parameters, + lr=optimizer_config.learning_rate, + betas=(optimizer_config.beta1, optimizer_config.beta2), + weight_decay=optimizer_config.weight_decay, + eps=optimizer_config.epsilon, + ) + else: + raise ValueError(f"Optimizer {optimizer_type} not supported") + + # Create mask scheduler + if config.get("mask_schedule", None) is not None: + schedule = config.mask_schedule.schedule + args = config.mask_schedule.get("params", {}) + mask_schedule = get_mask_schedule(schedule, **args) + else: + mask_schedule = get_mask_schedule(config.training.get("mask_schedule", "cosine")) + + lr_scheduler = get_scheduler( + config.lr_scheduler.scheduler, + optimizer=optimizer, + num_training_steps=config.training.max_train_steps, + num_warmup_steps=config.lr_scheduler.params.warmup_steps, + min_lr_scale=config.lr_scheduler.params.min_lr_scale + ) + + ################################## + # DATALOADER # + ################################# + logger.info("Creating dataloaders and lr_scheduler") + + total_batch_size_t2i_without_accum = config.training.batch_size_t2i * accelerator.num_processes + total_batch_size_t2i = ( + config.training.batch_size_t2i * accelerator.num_processes * config.training.gradient_accumulation_steps + ) + + # DataLoaders creation: + # We use webdataset for data loading. The dataloaders are created with sampling with replacement. + # We don't do dataset resuming here, instead we resample the shards and buffer each time. The sampling is stochastic. + # This means that the dataloading is not deterministic, but it's fast and efficient. + preproc_config = config.dataset.preprocessing + dataset_config = config.dataset.params + + # Data for generation + if config.dataset.gen_type == "t2i": + dataset = Text2ImageDataset( + train_shards_path_or_url=dataset_config.train_t2i_shards_path_or_url, + tokenizer=uni_prompting.text_tokenizer, # we want to get raw texts, tokenizer is just for length counting + max_seq_length=preproc_config.max_seq_length, + num_train_examples=config.experiment.max_train_examples_t2i, + per_gpu_batch_size=config.training.batch_size_t2i, + global_batch_size=total_batch_size_t2i_without_accum, + num_workers=dataset_config.num_workers, + resolution=preproc_config.resolution, + shuffle_buffer_size=dataset_config.shuffle_buffer_size, + pin_memory=dataset_config.pin_memory, + persistent_workers=dataset_config.persistent_workers, + external_caption_path=dataset_config.external_caption_path, + external_journeydb_caption_path=dataset_config.external_journeydb_caption_path, + external_laion12m_caption_path=dataset_config.external_laion12m_caption_path, + external_cc12m_caption_path=dataset_config.external_cc12m_caption_path, + external_text_to_image_2M_512_caption_path=dataset_config.external_text_to_image_2M_512_caption_path, + ) + train_dataloader_t2i = dataset.train_dataloader + num_update_steps_per_epoch = math.ceil( + train_dataloader_t2i.num_batches / config.training.gradient_accumulation_steps) + num_train_epochs = math.ceil(config.training.max_train_steps / num_update_steps_per_epoch) + + elif config.dataset.gen_type == "t2i_parquet": + # this part relies on the internal packages, which will not be released + num_update_steps_per_epoch = math.ceil(config.experiment.max_train_examples_t2i / total_batch_size_t2i) + num_train_epochs = math.ceil(config.training.max_train_steps / num_update_steps_per_epoch) + + train_dataloader_t2i = create_imagetext_dataloader( + train_shards_path_or_url=dataset_config.train_t2i_shards_path_or_url, + batch_size=config.training.batch_size_t2i, + image_size=preproc_config.resolution, + num_workers=dataset_config.num_workers, + num_readers=32, + predefined_steps=num_update_steps_per_epoch, + drop_last=True, + shuffle=True, + shuffle_buffer_size=dataset_config.shuffle_buffer_size + ) + + elif config.dataset.gen_type == "imagenet1k": + dataset_imagenet = ImageNetDataset( + dataset_config.train_t2i_shards_path_or_url, + image_size=preproc_config.resolution, + ) + + print('process index : ', + accelerator.process_index, ', ', accelerator.num_processes, + "Length: ", len(dataset_imagenet)) + + if accelerator.num_processes > 1: + sampler = DistributedSampler(dataset_imagenet, + num_replicas=accelerator.num_processes, + rank=accelerator.process_index, + shuffle=True, + ) + shuffle = False + else: + sampler = None + shuffle = True + + train_dataloader_t2i = DataLoader(dataset_imagenet, batch_size=config.training.batch_size_t2i, + sampler=sampler, collate_fn=dataset_imagenet.collate_fn, + shuffle=shuffle, num_workers=dataset_config.num_workers) + num_update_steps_per_epoch = math.ceil(len(dataset_imagenet) / total_batch_size_t2i) + num_train_epochs = math.ceil(config.training.max_train_steps / num_update_steps_per_epoch) + + else: + raise ValueError(f"Unsupported dataset type {config.dataset.type}") + + total_batch_size_mmu_without_accum = config.training.batch_size_mmu * accelerator.num_processes + # Data for image captioning + if config.dataset.und_type == "captioning": + dataset_mmu = Text2ImageDataset( + train_shards_path_or_url=dataset_config.train_mmu_shards_path_or_url, + tokenizer=uni_prompting.text_tokenizer, # we want to get raw texts + max_seq_length=preproc_config.max_seq_length, + num_train_examples=config.experiment.max_train_examples_mmu, + per_gpu_batch_size=config.training.batch_size_mmu, + global_batch_size=total_batch_size_mmu_without_accum, + num_workers=dataset_config.num_workers, + resolution=preproc_config.resolution, + shuffle_buffer_size=dataset_config.shuffle_buffer_size, + pin_memory=dataset_config.pin_memory, + persistent_workers=dataset_config.persistent_workers, + external_caption_path=dataset_config.external_caption_path, + external_journeydb_caption_path=dataset_config.external_journeydb_caption_path, + external_laion12m_caption_path=dataset_config.external_laion12m_caption_path, + external_cc12m_caption_path=dataset_config.external_cc12m_caption_path, + external_text_to_image_2M_512_caption_path=dataset_config.external_text_to_image_2M_512_caption_path, + external_ai2d_caption_path=dataset_config.external_ai2d_caption_path, + external_clevr_caption_path=dataset_config.external_clevr_caption_path, + external_docvqa_caption_path=dataset_config.external_docvqa_caption_path, + external_geo_caption_path=dataset_config.external_geo_caption_path, + is_captioning=True, + add_caption_prompt=dataset_config.add_caption_prompt, + ) + train_dataloader_mmu = dataset_mmu.train_dataloader + + elif config.dataset.und_type == "captioning_parquet": + train_dataloader_mmu = create_imagetext_dataloader( + train_shards_path_or_url=dataset_config.train_mmu_shards_path_or_url, + batch_size=config.training.batch_size_mmu, + image_size=preproc_config.resolution, + num_workers=dataset_config.num_workers, + num_readers=32, + predefined_steps=num_update_steps_per_epoch, + drop_last=True, + shuffle=True, + shuffle_buffer_size=dataset_config.shuffle_buffer_size, + is_captioning=True + ) + + else: + raise NotImplementedError(f"Unsupported dataset type {config.dataset.und_type}") + + # LLM pure text dataset: RefinedWeb + dataset_lm = RefinedWebDataset(data_path=dataset_config.train_lm_shards_path_or_url, + rank=accelerator.process_index, + world_size=accelerator.num_processes) + train_dataloader_lm = torch.utils.data.DataLoader(dataset_lm, batch_size=config.training.batch_size_lm, + sampler=None, collate_fn=dataset_lm.collate_fn, + num_workers=dataset_config.num_workers) + + dataset_instruct = ChatDataset(data_path=dataset_config.train_instruct_shards_path_or_url, + rank=accelerator.process_index, + world_size=accelerator.num_processes, + max_length=preproc_config.max_lm_text_length, + tokenizer=uni_prompting.text_tokenizer, + ) + + train_dataloader_instruct = torch.utils.data.DataLoader(dataset_instruct, batch_size=config.training.batch_size_lm, + sampler=None, collate_fn=dataset_instruct.collate_fn, + num_workers=dataset_config.num_workers) + + dataset_vqa = VQADataset( + json_path=dataset_config.external_vqa_caption_path, + image_root=dataset_config.vqa_images_path, + tokenizer=uni_prompting.text_tokenizer, + rank=accelerator.process_index, + world_size=accelerator.num_processes, + resolution=preproc_config.resolution, + max_length=preproc_config.max_seq_length + ) + train_dataloader_vqa = torch.utils.data.DataLoader(dataset_vqa, batch_size=config.training.batch_size_mmu, + sampler=None, collate_fn=dataset_vqa.collate_fn, + num_workers=dataset_config.num_workers) + + dataset_clevr2 = VQADataset( + json_path=dataset_config.external_clevr2_caption_path, + image_root=dataset_config.clevr2_images_path, + tokenizer=uni_prompting.text_tokenizer, + rank=accelerator.process_index, + world_size=accelerator.num_processes, + resolution=preproc_config.resolution, + max_length=preproc_config.max_seq_length + ) + train_dataloader_clevr2 = torch.utils.data.DataLoader(dataset_clevr2, batch_size=config.training.batch_size_mmu, + sampler=None, collate_fn=dataset_clevr2.collate_fn, + num_workers=dataset_config.num_workers) + + dataset_geo170k = VQADataset( + json_path=dataset_config.external_geo170k_caption_path, + image_root=dataset_config.geo170k_images_path, + tokenizer=uni_prompting.text_tokenizer, + rank=accelerator.process_index, + world_size=accelerator.num_processes, + resolution=preproc_config.resolution, + max_length=preproc_config.max_seq_length, + image_transform_method = "pad" + ) + train_dataloader_geo170k = torch.utils.data.DataLoader(dataset_geo170k, batch_size=config.training.batch_size_mmu, + sampler=None, collate_fn=dataset_geo170k.collate_fn, + num_workers=dataset_config.num_workers) + + # Combine these dataloaders into a single iterable model + iterables = { + "t2i_flow": train_dataloader_t2i, + "lm_flow": train_dataloader_lm, + "instruct_flow": train_dataloader_instruct, + "mmu_flow": train_dataloader_mmu, + "vqa_flow": train_dataloader_vqa, + "clevr2_flow": train_dataloader_clevr2, + "geo170k_flow": train_dataloader_geo170k, + } + + # + combined_dataloader = CombinedLoader(iterables, mode=config.dataset.combined_loader_mode) + + ################################## + # MODEL RESUME # + ################################# + global_step = 0 + first_epoch = 0 + start_step = 0 + + if config.experiment.resume_from_checkpoint: + dirs = os.listdir(config.experiment.output_dir) + logger.info(f"dirs: {dirs}") + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + logger.info(f"path: {path}") + if path is not None: + path = os.path.join(config.experiment.output_dir, path) + logger.info(f"Resuming from checkpoint: {path}") + global_step = start_step = int(os.path.basename(path).split("-")[1]) + first_epoch = global_step // num_update_steps_per_epoch + if os.path.exists(f'{path}/unwrapped_model/pytorch_model.bin'): + state_dict = torch.load(f'{path}/unwrapped_model/pytorch_model.bin', map_location="cpu") + model.load_state_dict(state_dict, strict=True) + del state_dict + elif os.path.exists(f'{path}/unwrapped_model/pytorch_model.bin.index.json'): + from safetensors.torch import load_file + from transformers.modeling_utils import load_sharded_checkpoint + load_sharded_checkpoint(model, f'{path}/unwrapped_model/') + # if safetensors sharded checkpoint exists + elif os.path.exists(f'{path}/unwrapped_model/model.safetensors.index.json'): + from transformers.modeling_utils import load_sharded_checkpoint + load_sharded_checkpoint( + model, + f'{path}/unwrapped_model/', + ) + else: + raise FileNotFoundError(f"Checkpoint {path}/unwrapped_model/pytorch_model.bin not found") + else: + logger.info("Not resuming from checkpoint") + + ################################## + # Prepare accelerator # + ################################# + logger.info("Preparing model, optimizer and dataloaders") + model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler) + + vq_model.to(device=accelerator.device) + + mask_dtype = model.get_input_embeddings().weight.dtype + + ################################## + # Training # + ################################# + logger.info("***** Running training *****") + logger.info(f" Num training steps = {config.training.max_train_steps}") + logger.info(f" Instantaneous batch size per device = {total_batch_size_per_gpu}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {config.training.gradient_accumulation_steps}") + + @torch.no_grad() + def prepare_inputs_and_labels( + pixel_values_or_image_ids: Union[torch.FloatTensor, torch.LongTensor], + texts: Union[str, list[str]], + min_masking_rate: float = 0.0, + is_train: bool = True, + seed: int = None + ): + + image_tokens = vq_model.get_code(pixel_values_or_image_ids) + image_tokens = image_tokens + len(uni_prompting.text_tokenizer) + # create MLM mask and labels + input_ids, labels, loss_weight, mask_prob = mask_or_random_replace_tokens( + image_tokens, + mask_id, + config, + mask_schedule=mask_schedule, + is_train=is_train, + seed=seed + ) + input_ids, masks, labels = uni_prompting((texts, input_ids, labels), 't2i') + return input_ids, labels, mask_prob, image_tokens, masks + + @torch.no_grad() + def prepare_inputs_and_labels_for_text( + texts: Union[str, list[str]], max_seq_len, eps=1e-3 + ): + # create MLM mask and labels + + input_ids_lm, attention_mask, labels_lm = uni_prompting((texts, max_seq_len), 'lm') + b, l = input_ids_lm.shape + t = torch.rand(b, device=input_ids_lm.device) + p_mask = (1 - eps) * t + eps + p_mask = p_mask[:, None].repeat(1, l) + + masked_indices = torch.rand((b, l), device=input_ids_lm.device) < p_mask + # 126336 is used for [MASK] token + noisy_batch = torch.where(masked_indices, mask_id, input_ids_lm) + masked_indices = noisy_batch == mask_id + answer_lengths_lm = torch.sum(attention_mask, dim=-1, keepdim=True) + answer_lengths_lm = answer_lengths_lm.clamp(min=1) + answer_lengths_lm = answer_lengths_lm.repeat(1, noisy_batch.shape[1]) + + return noisy_batch, labels_lm, p_mask, answer_lengths_lm + + @torch.no_grad() + def prepare_inputs_and_labels_for_chat_text( + texts: Union[str, list[str]], max_seq_len, eps=1e-3 + ): + # create MLM mask and labels + + input_ids_lm, prompt_mask, labels_lm = uni_prompting((texts, max_seq_len), 'lm_chat') + b, l = input_ids_lm.shape + t = torch.rand(b, device=input_ids_lm.device) + p_mask = (1 - eps) * t + eps + p_mask = p_mask[:, None].repeat(1, l) + + masked_indices = torch.rand((b, l), device=input_ids_lm.device) < p_mask + # 126336 is used for [MASK] token + noisy_batch = torch.where(masked_indices, mask_id, input_ids_lm) + masked_indices = noisy_batch == mask_id + noisy_batch[prompt_mask.bool()] = input_ids_lm[prompt_mask.bool()] + masked_indices = noisy_batch == mask_id + answer_lengths_lm = torch.sum((1 - prompt_mask), dim=-1, keepdim=True) + answer_lengths_lm = answer_lengths_lm.clamp(min=1) + answer_lengths_lm = answer_lengths_lm.repeat(1, noisy_batch.shape[1]) + + return noisy_batch, labels_lm, p_mask, answer_lengths_lm + + @torch.no_grad() + def prepare_inputs_and_labels_for_mmu( + input_ids_mmu, prompt_masks, labels_mmu, eps=1e-3 + ): + b, l = input_ids_mmu.shape + t = torch.rand(b, device=input_ids_mmu.device) + p_mask = (1 - eps) * t + eps + p_mask = p_mask[:, None].repeat(1, l) + + masked_indices = torch.rand((b, l), device=input_ids_mmu.device) < p_mask + # 126336 is used for [MASK] token + noisy_batch = torch.where(masked_indices, mask_id, input_ids_mmu) + masked_indices = noisy_batch == mask_id + noisy_batch[prompt_masks.bool()] = input_ids_mmu[prompt_masks.bool()] + masked_indices = noisy_batch == mask_id + + prompt_masks = prompt_masks.to(torch.int64) + answer_lengths = torch.sum((1 - prompt_masks), dim=-1, keepdim=True) + answer_lengths = answer_lengths.repeat(1, noisy_batch.shape[1]) + + return noisy_batch, labels_mmu, p_mask, answer_lengths + + + + batch_time_m = AverageMeter() + data_time_m = AverageMeter() + end = time.time() + + for epoch in range(first_epoch, num_train_epochs): + model.train() + for batch, batch_idx, dataloader_idx in combined_dataloader: + + # for loss calculation + batch_size_t2i = batch["t2i_flow"]["images"].shape[0] + batch_size_lm = len(batch["lm_flow"]["input_ids"]) + batch_size_mmu = batch["mmu_flow"]["images"].shape[0] + + # *-------*-------*-------*-------*-------*-------*-------*-------*-------*-------*-------* + # Build formatted sequences for class-conditional/text-to-image generation + # *-------*-------*-------*-------*-------*-------*-------*-------*-------*-------*-------* + pixel_values, texts = batch["t2i_flow"]["images"], batch["t2i_flow"]["input_ids"] + pixel_values = pixel_values.to(accelerator.device, non_blocking=True) + data_time_m.update(time.time() - end) + + # Encode images to image tokens, mask them and create input and labels + ( + input_ids, + labels, + mask_prob, + image_tokens_ori, + t2i_masks + ) = prepare_inputs_and_labels(pixel_values, texts, config.training.min_masking_rate) + + # *-------*-------*-------*-------*-------*-------*-------*-------*-------*-------*-------* + # Build formatted sequences for language modeling + # *-------*-------*-------*-------*-------*-------*-------*-------*-------*-------*-------* + max_seq_len = input_ids.shape[-1] + + probs = [config.training.base_in_lm_coeff, config.training.instruct_in_lm_coeff] + probs_total = sum(probs) + probs = [p / probs_total for p in probs] + cum_probs = [sum(probs[:i+1]) for i in range(len(probs))] + rand_val = random.random() + if rand_val < cum_probs[0]: + texts_lm = batch["lm_flow"]["input_ids"] + ( + input_ids_lm, + labels_lm, + p_mask_lm, + answer_lengths_lm + ) = prepare_inputs_and_labels_for_text(texts_lm, max_seq_len) + input_ids = torch.cat((input_ids, input_ids_lm.to(input_ids.device)), dim=0) + labels = torch.cat((labels, labels_lm.to(input_ids.device)), dim=0) + else: + texts_lm = batch["instruct_flow"]["input_ids"] + ( + input_ids_lm, + labels_lm, + p_mask_lm, + answer_lengths_lm + ) = prepare_inputs_and_labels_for_chat_text(texts_lm, max_seq_len) + input_ids = torch.cat((input_ids, input_ids_lm.to(input_ids.device)), dim=0) + labels = torch.cat((labels, labels_lm.to(input_ids.device)), dim=0) + + + + # *-------*-------*-------*-------*-------*-------*-------*-------*-------*-------*-------* + # Build formatted sequences for captioning/multimodal understanding + # *-------*-------*-------*-------*-------*-------*-------*-------*-------*-------*-------* + if "llava" in config.dataset.und_type: + pixel_values_mmu, input_ids_mmu, labels_mmu = (batch["mmu_flow"]["images"], batch["mmu_flow"]["input_ids"],batch["mmu_flow"]["labels"]) + pixel_values_mmu = pixel_values_mmu.to(accelerator.device, non_blocking=True) + input_ids_mmu = input_ids_mmu.to(accelerator.device, non_blocking=True) + image_tokens_mmu = vq_model.get_code(pixel_values_mmu) + image_tokens_mmu = image_tokens_mmu + len(uni_prompting.text_tokenizer) + + input_ids_mmu = torch.cat([ + (torch.ones(input_ids_mmu.shape[0], 1) * uni_prompting.sptids_dict['<|mmu|>']).to( + accelerator.device), + (torch.ones(input_ids_mmu.shape[0], 1) * uni_prompting.sptids_dict['<|soi|>']).to( + accelerator.device), + image_tokens_mmu, + (torch.ones(input_ids_mmu.shape[0], 1) * uni_prompting.sptids_dict['<|eoi|>']).to( + accelerator.device), + input_ids_mmu, + ], dim=1).long() + + labels_mmu = torch.cat([ + (torch.ones(input_ids_mmu.shape[0], 1) * uni_prompting.ignore_id).to(accelerator.device), + (torch.ones(input_ids_mmu.shape[0], 1) * uni_prompting.ignore_id).to(accelerator.device), + torch.ones_like(image_tokens_mmu) * uni_prompting.ignore_id, + (torch.ones(input_ids_mmu.shape[0], 1) * uni_prompting.ignore_id).to(accelerator.device), + labels_mmu.to(accelerator.device) + ], dim=1).long() + + else: + probs = [config.training.cot_in_mmu_coeff, config.training.vqa_in_mmu_coeff, config.training.clevr2_in_mmu_coeff, config.training.geo170k_in_mmu_coeff] + probs_total = sum(probs) + probs = [p / probs_total for p in probs] + cum_probs = [sum(probs[:i+1]) for i in range(len(probs))] + rand_val = random.random() + if rand_val < cum_probs[0]: + pixel_values_mmu, texts_mmu = batch["mmu_flow"]["images"], batch["mmu_flow"]["input_ids"] + elif rand_val < cum_probs[1]: + pixel_values_mmu, texts_mmu = batch["vqa_flow"]["images"], batch["vqa_flow"]["input_ids"] + elif rand_val < cum_probs[2]: + pixel_values_mmu, texts_mmu = batch["clevr2_flow"]["images"], batch["clevr2_flow"]["input_ids"] + else: + pixel_values_mmu, texts_mmu = batch["geo170k_flow"]["images"], batch["geo170k_flow"]["input_ids"] + pixel_values_mmu = pixel_values_mmu.to(accelerator.device, non_blocking=True) + image_tokens_mmu = vq_model.get_code(pixel_values_mmu) + image_tokens_mmu = image_tokens_mmu + len(uni_prompting.text_tokenizer) + + input_ids_mmu, prompt_masks, labels_mmu = uni_prompting((image_tokens_mmu, texts_mmu), 'mmu') + ( + input_ids_mmu, + labels_mmu, + p_mask_mmu, + answer_lengths + ) = prepare_inputs_and_labels_for_mmu(input_ids_mmu, prompt_masks, labels_mmu) + input_ids_mmu = input_ids_mmu.to(accelerator.device, non_blocking=True) + + + input_ids = torch.cat((input_ids, input_ids_mmu.to(input_ids.device)), dim=0) + labels = torch.cat((labels, labels_mmu.to(input_ids.device)), dim=0) + + if global_step == 0 and epoch == 0: + logger.info("Input ids: {}".format(input_ids)) + logger.info("Labels: {}".format(labels)) + + with accelerator.accumulate(model): + logits, loss_t2i, loss_lm, loss_mmu = model.forward_process( + input_ids=input_ids, + labels=labels, + batch_size_t2i=batch_size_t2i, + batch_size_lm=batch_size_lm, + batch_size_mmu=batch_size_mmu, + max_seq_length=config.dataset.preprocessing.max_seq_length, + p_mask_lm=p_mask_lm, + p_mask_mmu=p_mask_mmu, + answer_lengths=answer_lengths, + t2i_masks=t2i_masks, + answer_lengths_lm=answer_lengths_lm + ) + # Gather the losses across all processes for logging (if we use distributed training). + avg_loss_t2i = accelerator.gather(loss_t2i.repeat(config.training.batch_size_t2i)).mean() + avg_loss_lm = accelerator.gather(loss_lm.repeat(config.training.batch_size_lm)).mean() + avg_loss_mmu = accelerator.gather(loss_mmu.repeat(config.training.batch_size_mmu)).mean() + loss = config.training.t2i_coeff * loss_t2i + \ + config.training.lm_coeff * loss_lm + \ + config.training.mmu_coeff * loss_mmu + + avg_masking_rate = accelerator.gather(mask_prob.repeat(config.training.batch_size_t2i)).mean() + + accelerator.backward(loss) + + if config.training.max_grad_norm is not None and accelerator.sync_gradients: + accelerator.clip_grad_norm_(model.parameters(), config.training.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + + # log gradient norm before zeroing it + if ( + accelerator.sync_gradients + and (global_step + 1) % config.experiment.log_grad_norm_every == 0 + and accelerator.is_main_process + ): + log_grad_norm(model, accelerator, global_step + 1) + + optimizer.zero_grad(set_to_none=True) + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + + batch_time_m.update(time.time() - end) + end = time.time() + + # Log metrics + if (global_step + 1) % config.experiment.log_every == 0: + samples_per_second_per_gpu = ( + config.training.gradient_accumulation_steps * total_batch_size_per_gpu / batch_time_m.val + ) + logs = { + "step_loss_t2i": avg_loss_t2i.item(), + "step_loss_mmu": avg_loss_mmu.item(), + "step_loss_lm": avg_loss_lm.item(), + "lr": lr_scheduler.get_last_lr()[0], + "avg_masking_rate": avg_masking_rate.item(), + "samples/sec/gpu": samples_per_second_per_gpu, + "data_time": data_time_m.val, + "batch_time": batch_time_m.val, + } + accelerator.log(logs, step=global_step + 1) + + logger.info( + f"Step: {global_step + 1} " + f"Loss_t2i: {avg_loss_t2i.item():0.4f} " + f"Loss_mmu: {avg_loss_mmu.item():0.4f} " + f"Loss_lm: {avg_loss_lm.item():0.4f} " + f"Data (t): {data_time_m.val:0.4f}, {samples_per_second_per_gpu:0.2f}/s/gpu " + f"Batch (t): {batch_time_m.val:0.4f} " + f"LR: {lr_scheduler.get_last_lr()[0]:0.6f}" + ) + + # resetting batch / data time meters per log window + batch_time_m.reset() + data_time_m.reset() + + + if (global_step + 1) % config.experiment.save_every == 0: + save_checkpoint(model, config, accelerator, global_step + 1, uni_prompting) + + if ((global_step + 1) % config.experiment.generate_every == 0 or global_step == start_step) and accelerator.is_main_process: + quantative_images( + model, + vq_model, + uni_prompting, + accelerator, + config, + global_step + 1, + mask_schedule=mask_schedule, + force_no_cfg=False + ) + + generate_images( + model, + vq_model, + uni_prompting, + accelerator, + config, + global_step + 1, + mask_schedule=mask_schedule, + force_no_cfg=False + ) + + visualize_predictions( + model, + vq_model, + uni_prompting, + config, + global_step + 1, + input_ids, + image_tokens_ori, + batch["t2i_flow"]["images"], + texts, + logits, + accelerator + ) + + understanding_images( + model, + vq_model, + uni_prompting, + accelerator, + config, + global_step + 1, + ) + + generate_chat_text( + model, + uni_prompting, + accelerator, + config, + global_step + 1, + ) + + global_step += 1 + # Stop training if max steps is reached + if global_step >= config.training.max_train_steps: + break + # End for + + accelerator.wait_for_everyone() + + # Evaluate and save checkpoint at the end of training + save_checkpoint(model, config, accelerator, global_step, uni_prompting) + + # Save the final trained checkpoint + if accelerator.is_main_process: + model = accelerator.unwrap_model(model) + model.save_pretrained(config.experiment.output_dir, safe_serialization=True) + + accelerator.end_training() + + +@torch.no_grad() +def visualize_predictions( + model, + vq_model, + uni_prompting, + config, + global_step, + input_ids, + image_tokens_ori, + ori_images, + texts, + logits, + accelerator +): + logger.info("Visualizing predictions...") + model.eval() + + recons_images = vq_model.decode_code(image_tokens_ori - len(uni_prompting.text_tokenizer)) + recons_images = torch.clamp((recons_images + 1.0) / 2.0, min=0.0, max=1.0) + recons_images *= 255.0 + recons_images = recons_images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8) + + images = torch.clamp((ori_images + 1.0) / 2.0, min=0.0, max=1.0) + images *= 255.0 + images = images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8) + predictions = logits[:config.training.batch_size_t2i, -(config.model.mmada.num_vq_tokens + 1):-1:, len(uni_prompting.text_tokenizer) + config.model.mmada.num_new_special_tokens: len(uni_prompting.text_tokenizer) + config.model.mmada.num_new_special_tokens + config.model.mmada.codebook_size] + predictions = predictions.argmax(axis=-1) + mask_token_id = accelerator.unwrap_model(model).config.mask_token_id - len(uni_prompting.text_tokenizer) + input_ids = input_ids[:config.training.batch_size_t2i, -(config.model.mmada.num_vq_tokens + 1):-1:] - len(uni_prompting.text_tokenizer) + mask_ratio = list((torch.where(input_ids == mask_token_id, 1, 0).sum( + dim=-1) / config.model.mmada.num_vq_tokens).cpu().numpy()) + predicted_images = torch.where(input_ids == mask_token_id, predictions, input_ids) + predicted_images = vq_model.decode_code(predicted_images) + predicted_images = torch.clamp((predicted_images + 1.0) / 2.0, min=0.0, max=1.0) + predicted_images *= 255.0 + predicted_images = predicted_images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8) + predicted_images = np.concatenate((images, recons_images, predicted_images), 2) + pil_images = [Image.fromarray(image) for image in predicted_images] + + # Log images + wandb_images = [wandb.Image(image, caption=f'mask ratio: {r:0.2f} \n caption: {texts[i]}') for i, (image, r) in + enumerate(zip(pil_images, mask_ratio))] + wandb.log({"Original images v.s. Reconstructed images v.s. Predicted images": wandb_images}, step=global_step) + + model.train() + + +@torch.no_grad() +def generate_images( + model, + vq_model, + uni_prompting, + accelerator, + config, + global_step, + mask_schedule, + force_no_cfg = False +): + logger.info("Generating images...") + model.eval() + + with open(config.dataset.params.validation_prompts_file, "r") as f: + validation_prompts = f.read().splitlines() + + + mask_dtype = model.get_input_embeddings().weight.dtype + mask_token_id = accelerator.unwrap_model(model).config.mask_token_id + image_tokens = torch.ones((len(validation_prompts), config.model.mmada.num_vq_tokens), dtype=torch.long, + device=accelerator.device) * mask_token_id + input_ids, attention_mask = uni_prompting((validation_prompts, image_tokens), 't2i_gen') + if not force_no_cfg and config.training.guidance_scale > 0: + uncond_input_ids, uncond_attention_mask = uni_prompting(([''] * len(validation_prompts), image_tokens), 't2i_gen') + cfg_scale = config.training.guidance_scale + else: + uncond_input_ids = None + uncond_attention_mask = None + cfg_scale = 0 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + else: + weight_dtype = torch.float32 + + + with torch.autocast("cuda", dtype=weight_dtype, enabled=accelerator.mixed_precision != "no"): + # Generate images + gen_token_ids = accelerator.unwrap_model(model).t2i_generate( + input_ids=input_ids, + uncond_input_ids=uncond_input_ids, + attention_mask=attention_mask, + uncond_attention_mask=uncond_attention_mask, + guidance_scale=cfg_scale, + temperature=config.training.get("generation_temperature", 1.0), + timesteps=config.training.generation_timesteps, + noise_schedule=mask_schedule, + noise_type=config.training.get("noise_type", "mask"), + predict_all_tokens=config.training.get("predict_all_tokens", False), + seq_len=config.model.mmada.num_vq_tokens, + uni_prompting=uni_prompting, + config=config, + ) + gen_token_ids = torch.clamp(gen_token_ids, max=accelerator.unwrap_model(model).config.codebook_size - 1, min=0) + images = vq_model.decode_code(gen_token_ids) + + model.train() + + if config.training.get("pre_encode", False): + del vq_model + + # Convert to PIL images + images = torch.clamp((images + 1.0) / 2.0, min=0.0, max=1.0) + images *= 255.0 + images = images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8) + pil_images = [Image.fromarray(image) for image in images] + + # Log images + wandb_images = [wandb.Image(image, caption=validation_prompts[i]) for i, image in enumerate(pil_images)] + wandb.log({f"Generated images with cfg {cfg_scale}": wandb_images}, step=global_step) + + + + + +@torch.no_grad() +def quantative_images( + model, + vq_model, + uni_prompting, + accelerator, + config, + global_step, + mask_schedule, + force_no_cfg = False +): + logger.info("Quantative images...") + model.eval() + clip_score_fn = partial(clip_score, model_name_or_path="/data_storage/shared/pretrained_models/") + image_reward_model = RM.load("/data_storage/shared/pretrained_models/ImageReward/ImageReward.pt") + # read validation prompts from file + with open(config.validation.quantative_prompts_file, "r") as f: + validation_prompts = f.read().splitlines() + mask_dtype = model.get_input_embeddings().weight.dtype + mask_token_id = accelerator.unwrap_model(model).config.mask_token_id + image_tokens = torch.ones((len(validation_prompts), config.model.mmada.num_vq_tokens), dtype=torch.long, + device=accelerator.device) * mask_token_id + input_ids, attention_mask = uni_prompting((validation_prompts, image_tokens), 't2i_gen') + if not force_no_cfg and config.training.guidance_scale > 0: + uncond_input_ids, uncond_attention_mask = uni_prompting(([''] * len(validation_prompts), image_tokens), 't2i_gen') + cfg_scale = config.training.guidance_scale + else: + uncond_input_ids = None + uncond_attention_mask = None + cfg_scale = 0 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + else: + weight_dtype = torch.float32 + + validation_batch_size = config.validation.quantative_batch_size + + pil_images = [] + clip_scores = [] + image_rewards = [] + for i in range(0, len(validation_prompts), validation_batch_size): + batch_input_ids = input_ids[i:i+validation_batch_size] + batch_attention_mask = attention_mask[i:i+validation_batch_size] + batch_uncond_input_ids = uncond_input_ids[i:i+validation_batch_size] + batch_uncond_attention_mask = uncond_attention_mask[i:i+validation_batch_size] + with torch.autocast("cuda", dtype=weight_dtype, enabled=accelerator.mixed_precision != "no"): + # Generate images + gen_token_ids = accelerator.unwrap_model(model).t2i_generate( + input_ids=batch_input_ids, + uncond_input_ids=batch_uncond_input_ids, + attention_mask=batch_attention_mask, + uncond_attention_mask=batch_uncond_attention_mask, + guidance_scale=cfg_scale, + temperature=config.training.get("generation_temperature", 1.0), + timesteps=config.training.generation_timesteps, + noise_schedule=mask_schedule, + noise_type=config.training.get("noise_type", "mask"), + predict_all_tokens=config.training.get("predict_all_tokens", False), + seq_len=config.model.mmada.num_vq_tokens, + uni_prompting=uni_prompting, + config=config, + ) + # In the beginning of training, the model is not fully trained and the generated token ids can be out of range + # so we clamp them to the correct range. + gen_token_ids = torch.clamp(gen_token_ids, max=accelerator.unwrap_model(model).config.codebook_size - 1, min=0) + images = vq_model.decode_code(gen_token_ids) + images = torch.clamp((images + 1.0) / 2.0, min=0.0, max=1.0) + images *= 255.0 + image_tensor = images.to(torch.uint8) + images = images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8) + batch_pil_images = [Image.fromarray(image) for image in images] + pil_images.extend(batch_pil_images) + + # calculate CLIP score + batch_clip_score = clip_score_fn(image_tensor, validation_prompts[i:i+validation_batch_size]) + # calculate image reward score + for j in range(validation_batch_size): + clip_scores.append(clip_score_fn(image_tensor[j], validation_prompts[i+j])) + image_reward_score = image_reward_model.score(validation_prompts[i+j], batch_pil_images[j]) + image_rewards.append(image_reward_score) + + clip_scores = torch.tensor(clip_scores) + image_rewards = torch.tensor(image_rewards) + logger.info(f"clip_scores: {clip_scores}, image_rewards: {image_rewards}") + clip_scores_mean = clip_scores.mean() + image_rewards_mean = image_rewards.mean() + logger.info(f"CLIP score mean: {clip_scores_mean}, Image reward score mean: {image_rewards_mean}") + accelerator.log({"clip_score": clip_scores_mean, "image_reward_score": image_rewards_mean}, step=global_step) + + + + # Log images + wandb_images = [wandb.Image(image, caption=f"{validation_prompts[i]} \n CLIP score: {clip_scores[i]}, Image reward score: {image_rewards[i]}") for i, image in enumerate(pil_images[:validation_batch_size])] + wandb.log({f"Quantative images with cfg {cfg_scale}": wandb_images}, step=global_step) + + + if config.training.get("pre_encode", False): + del vq_model + + model.train() + + + + + +@torch.no_grad() +def understanding_images( + model, + vq_model, + uni_prompting, # åŒ…å«äŗ† text_tokenizer + accelerator, + config, + global_step, +): + """ + Processes images and multi-turn conversation prompts for image understanding, + generates responses, and logs results to Weights & Biases. + Uses tokenizer.apply_chat_template for handling conversation history. + """ + logger.info("Understanding images (multi-turn)...") + model.eval() + prompts_file_path = config.dataset.params.mmu_validation_prompts_file + image_root = config.dataset.params.mmu_image_root + try: + with open(prompts_file_path, 'r', encoding='utf-8') as f: + validation_data = json.load(f) + logger.info(f"Successfully loaded {len(validation_data)} validation items from {prompts_file_path}") + except Exception as e: + logger.error(f"Error loading prompts from {prompts_file_path}: {e}. Skipping image understanding.") + model.train() + return + wandb_logs = [] + device = accelerator.device + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + else: + weight_dtype = torch.float32 + for item in validation_data: + file_name = item.get('file_name') + messages = item.get('messages') + if not file_name or not messages: + logger.warning(f"Skipping item due to missing 'file_name' or 'messages': {item}") + continue + image_path = os.path.join(image_root, file_name) + if not os.path.exists(image_path): + logger.warning(f"Image file not found: {image_path}. Skipping.") + continue + try: + image_ori = Image.open(image_path).convert("RGB") + if any(tag in file_name for tag in ['ai2d', 'clevr', 'docvqa', 'geo', 'llava']): + image = image_transform_squash(image_ori, resolution=config.dataset.preprocessing.resolution).to(device) + else: + image = image_transform(image_ori, resolution=config.dataset.preprocessing.resolution).to(device) + image = image.unsqueeze(0) + image_tokens = vq_model.get_code(image) + len(uni_prompting.text_tokenizer) + batch_size = image_tokens.shape[0] + text_token_ids = uni_prompting.text_tokenizer.apply_chat_template( + messages, + tokenize=True, + add_generation_prompt=True, + return_tensors="pt" + ).to(device) + input_ids = torch.cat([ + (torch.ones(batch_size, 1) * uni_prompting.sptids_dict['<|mmu|>']).to(device), + (torch.ones(batch_size, 1) * uni_prompting.sptids_dict['<|soi|>']).to(device), + image_tokens, + (torch.ones(batch_size, 1) * uni_prompting.sptids_dict['<|eoi|>']).to(device), + text_token_ids + ], dim=1).long() + with torch.autocast("cuda", dtype=weight_dtype, enabled=accelerator.mixed_precision != "no"): + output_ids = accelerator.unwrap_model(model).mmu_generate( + input_ids, + max_new_tokens=config.dataset.preprocessing.max_seq_length, + steps=config.dataset.preprocessing.max_seq_length // 2, + block_length=config.dataset.preprocessing.max_seq_length // 4, + ) + generated_ids = output_ids[:, input_ids.shape[1]:] + response_text = uni_prompting.text_tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0] + conversation_str = f"Image: {file_name}\n" + "="*20 + "\n" + conversation_str = f"Image: {file_name}\n" + "="*20 + "\n" + for msg in messages: + role_prefix = "User: " if msg['role'] == 'user' else "Assistant: " + conversation_str += f"{role_prefix}{msg['content']}\n" + conversation_str += f"Assistant (Generated): {response_text}\n" + log_image_tensor = torch.clamp((image.squeeze(0) + 1.0) / 2.0, min=0.0, max=1.0) * 255.0 + log_image_np = log_image_tensor.permute(1, 2, 0).cpu().numpy().astype(np.uint8) + pil_image = Image.fromarray(log_image_np) + wandb_logs.append(wandb.Image(pil_image, caption=conversation_str.strip())) + except Exception as e: + logger.error(f"Error processing {file_name}: {e}", exc_info=True) + if wandb_logs: + try: + wandb.log({"Understanding images (multi-turn)": wandb_logs}, step=global_step) + logger.info(f"Logged {len(wandb_logs)} understanding image results to W&B for step {global_step}.") + except Exception as e: + logger.error(f"Failed to log understanding images to W&B: {e}") + else: + logger.warning("No images were successfully processed for understanding in this step.") + model.train() + +@torch.no_grad() +def generate_chat_text( + model, + uni_prompting, + accelerator, + config, + global_step, +): + logger.info("Generating chat text...") + model.eval() + + df = pandas.read_json(config.dataset.params.lm_chat_validation_jsonl, lines=True) + prompts = df['question'].tolist() + responses = [''] * len(prompts) + + device = accelerator.device + + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + else: + weight_dtype = torch.float32 + + html_content = "
    " + html_content += f"

    Step {global_step}

    " + + for i, prompt in enumerate(prompts): + original_prompt = prompt + + prompt_with_tags = "<|start_header_id|>user<|end_header_id|>\n" + f"{prompt}" + "<|start_header_id|>assistant<|end_header_id|>\n" + token_ids = uni_prompting.text_tokenizer([prompt_with_tags])['input_ids'][0] + token_ids = [uni_prompting.text_tokenizer.bos_token_id] + token_ids + input_ids = torch.tensor(token_ids).unsqueeze(0).to(device) + + with torch.autocast("cuda", dtype=weight_dtype, enabled=accelerator.mixed_precision != "no"): + output_ids = accelerator.unwrap_model(model).mmu_generate( + input_ids, + max_new_tokens=config.dataset.preprocessing.max_seq_length, + steps=config.dataset.preprocessing.max_lm_text_length // 2, + block_length=config.dataset.preprocessing.max_seq_length // 4 + ) + text = uni_prompting.text_tokenizer.batch_decode(output_ids[:, input_ids.shape[1]:], skip_special_tokens=True) + responses[i] += text[0] + + escaped_prompt = html.escape(original_prompt) + escaped_response = html.escape(responses[i]) + html_content += f""" +
    +

    Prompt

    +

    {escaped_prompt}

    +

    Response

    +

    {escaped_response}

    +
    + """ + + html_content += "
    " + + model.train() + + wandb.log({"chat_text_generation": wandb.Html(html_content)}, step=global_step) + + + + +def save_checkpoint(model, config, accelerator, global_step, uni_prompting): + output_dir = config.experiment.output_dir + checkpoints_total_limit = config.experiment.get("checkpoints_total_limit", None) + + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if accelerator.is_main_process and checkpoints_total_limit is not None: + checkpoints = os.listdir(output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= checkpoints_total_limit: + num_to_remove = len(checkpoints) - checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = Path(output_dir) / f"checkpoint-{global_step}" + + # retrieve the model on all processes for deepspeed stage 3 to work then save on one process (we are not using stage 3 yet) + # XXX: could also make this conditional on deepspeed + state_dict = accelerator.get_state_dict(model) + if accelerator.is_main_process: + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.save_pretrained( + save_path / "unwrapped_model", + save_function=accelerator.save, + state_dict=state_dict, + safe_serialization=True + ) + json.dump({"global_step": global_step}, (save_path / "metadata.json").open("w+")) + logger.info(f"Saved state to {save_path}") + + # save tokenizer + uni_prompting.text_tokenizer.save_pretrained(save_path/ "unwrapped_model") + + +def log_grad_norm(model, accelerator, global_step): + for name, param in model.named_parameters(): + if param.grad is not None: + grads = param.grad.detach().data + grad_norm = (grads.norm(p=2) / grads.numel()).item() + accelerator.log({"grad_norm/" + name: grad_norm}, step=global_step) + + + + + + +if __name__ == "__main__": + main() diff --git a/MMaDA/training/train_mmada_t2s.py b/MMaDA/training/train_mmada_t2s.py new file mode 100644 index 0000000000000000000000000000000000000000..45bb55ba2635a230d4c0bc1b25171edae142f240 --- /dev/null +++ b/MMaDA/training/train_mmada_t2s.py @@ -0,0 +1,617 @@ +# Copyright 2025 AIDAS Lab +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +warnings.filterwarnings("ignore") + +import os +import sys +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +# sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +os.environ["TOKENIZERS_PARALLELISM"] = "true" +import json +import logging +import math +import shutil +import time +from pathlib import Path +from typing import Union, List + +import numpy as np +from PIL import Image +from omegaconf import OmegaConf +import wandb +import torch +from torch.optim import AdamW +from lightning.pytorch.utilities import CombinedLoader + +from transformers import AutoTokenizer, AutoConfig +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import DistributedType, set_seed + +# +++++ I2I-specific Imports +++++ +from datasets import load_dataset +from torch.utils.data import Dataset, DataLoader +from tqdm.auto import tqdm +# ++++++++++++++++++++++++++++++ + +# +++++ S2T-specific Imports +++++ +from models.modeling_emova_speech_tokenizer import EMOVASpeechTokenizer +from datasets import load_dataset +from torch.utils.data import Dataset, DataLoader, DistributedSampler +from tqdm.auto import tqdm +from training.data import SpeechTextDataset +# import librosa +# ++++++++++++++++++++++++++++++ + +from training.data import Text2ImageDataset +from training.utils import get_config, flatten_omega_conf, image_transform +from training.imagenet_dataset import ImageNetDataset +from parquet import RefinedWebDataset + +from models import MAGVITv2, get_mask_schedule, MMadaModelLM, MMadaConfig +from training.prompting_utils import UniversalPrompting +from models.lr_schedulers import get_scheduler +from models.logging import set_verbosity_info, set_verbosity_error + +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler + + +SYSTEM_PROMPT_LEN = 28 + +from training.utils import get_config, flatten_omega_conf, mask_or_random_replace_tokens, AverageMeter + +try: + import apex + + is_apex_available = True +except ImportError: + is_apex_available = False + +logger = get_logger(__name__, log_level="INFO") + +def resize_vocab(model, config): + logger.info(f"Resizing token embeddings to {config.model.mmada.new_vocab_size}") + model.resize_token_embeddings(config.model.mmada.new_vocab_size) + +def get_vq_model_class(model_type): + if model_type == "magvitv2": + return MAGVITv2 + elif model_type == "emova": + return EMOVASpeechTokenizer.from_pretrained( + "Emova-ollm/emova_speech_tokenizer_hf" + ) + else: + raise ValueError(f"model_type {model_type} not supported.") + +def collate_fn(batch): + # In this setup, the tokenizer handles batching of audio paths + return { + 'audio_path': [item['audio_path'] for item in batch], + 'text': [item['text'] for item in batch], + } +# ++++++++++++++++++++++++++++++++++ + +def main(): + ######################### + # SETUP Accelerator # + ######################### + config = get_config() + + # Enable TF32 on Ampere GPUs + if config.training.enable_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.benchmark = True + torch.backends.cudnn.deterministic = False + + config.experiment.logging_dir = str(Path(config.experiment.output_dir) / "logs") + accelerator = Accelerator( + gradient_accumulation_steps=config.training.gradient_accumulation_steps, + mixed_precision=config.training.mixed_precision, + log_with="wandb", + project_dir=config.experiment.logging_dir, + ) + + total_batch_size_per_gpu = config.training.batch_size_s2t + total_batch_size = ( + config.training.batch_size_s2t + * accelerator.num_processes * config.training.gradient_accumulation_steps + ) + + if accelerator.distributed_type == DistributedType.DEEPSPEED: + accelerator.state.deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = ( + total_batch_size_per_gpu + ) + + ##################################### + # SETUP LOGGING, SEED and CONFIG # + ##################################### + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + set_verbosity_info() + else: + set_verbosity_error() + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + resume_wandb_run = config.wandb.resume + run_id = config.wandb.get("run_id", None) + if run_id is None: + resume_wandb_run = False + run_id = wandb.util.generate_id() + config.wandb.run_id = run_id + + wandb_init_kwargs = dict( + name=config.experiment.name, + id=run_id, + resume=resume_wandb_run, + entity=config.wandb.get("entity", None), + config_exclude_keys=[], + ) + wandb_config = {k: v for k, v in flatten_omega_conf(config, resolve=True)} + wandb_config.pop("experiment.resume_from_checkpoint") + + accelerator.init_trackers( + config.experiment.project, + config=wandb_config, + init_kwargs={"wandb": wandb_init_kwargs}, + ) + + if accelerator.is_main_process: + os.makedirs(config.experiment.output_dir, exist_ok=True) + config_path = Path(config.experiment.output_dir) / "config.yaml" + logging.info(f"Saving config to {config_path}") + OmegaConf.save(config, config_path) + + # If passed along, set the training seed now. + if config.training.seed is not None: + set_seed(config.training.seed) + + ######################### + # MODELS and OPTIMIZER # + ######################### + logger.info("Loading models and optimizer") + logger.info("="*50) + logger.info(f"max_train_steps from config: {config.training.max_train_steps}") + logger.info("="*50) + + tokenizer = AutoTokenizer.from_pretrained(config.model.mmada.tokenizer_path, padding_side="left") + + uni_prompting = UniversalPrompting(tokenizer, max_text_len=config.dataset.preprocessing.max_seq_length, + special_tokens=("<|s2t|>", "<|soa|>", "<|eoa|>", "<|soi|>", "<|eoi|>", "<|sov|>", "<|eov|>", "<|t2i|>", "<|mmu|>", "<|t2v|>", "<|v2v|>", "<|lvg|>", "<|t2s|>"), + ignore_id=-100, cond_dropout_prob=config.training.cond_dropout_prob, use_reserved_token=True) + + logger.info('special tokens : \n', uni_prompting.sptids_dict) + + # VQ model for processing image into discrete tokens + vq_model = get_vq_model_class(config.model.vq_model.type) + vq_model = vq_model.from_pretrained(config.model.vq_model.vq_model_name).to(accelerator.device) + vq_model.eval() + vq_model.requires_grad_(False) + + model = MMadaModelLM.from_pretrained(config.model.mmada.pretrained_model_path, torch_dtype=torch.bfloat16).to(accelerator.device) + unwrapped_model = accelerator.unwrap_model(model) + original_vocab_size = unwrapped_model.get_input_embeddings().weight.shape[0] + logger.info("="*50) + logger.info(f"Calling resize_vocab...") + logger.info(f"Vocab size BEFORE resizing: {original_vocab_size}") + + resize_vocab(unwrapped_model, config) + + resized_vocab_size = unwrapped_model.get_input_embeddings().weight.shape[0] + logger.info(f"Vocab size AFTER resizing: {resized_vocab_size}") + logger.info(f"Config 'new_vocab_size': {config.model.mmada.new_vocab_size}") + + if resized_vocab_size == config.model.mmada.new_vocab_size: + logger.info("āœ… Vocab resize successful!") + else: + logger.info("āŒ Vocab resize FAILED or did not match config!") + logger.info("="*50) + + mask_id = model.config.mask_token_id + + ################################## + # Optimizer and LR scheduler # + ################################# + optimizer_config = config.optimizer.params + + # no decay on bias and layernorm and embedding + no_decay = ["bias", "layer_norm.weight", "mlm_ln.weight", "embeddings.weight"] + optimizer_grouped_parameters = [ + { + "params": [p for n, p in model.named_parameters() if + p.requires_grad and not any(nd in n for nd in no_decay)], + "weight_decay": optimizer_config.weight_decay, + }, + { + "params": [p for n, p in model.named_parameters() if + p.requires_grad and any(nd in n for nd in no_decay)], + "weight_decay": 0.0, + }, + ] + + optimizer_type = config.optimizer.name + if optimizer_type == "adamw": + optimizer = AdamW( + optimizer_grouped_parameters, + lr=optimizer_config.learning_rate, + betas=(optimizer_config.beta1, optimizer_config.beta2), + weight_decay=optimizer_config.weight_decay, + eps=optimizer_config.epsilon, + ) + else: + raise ValueError(f"Optimizer {optimizer_type} not supported") + + # Create mask scheduler + if config.get("mask_schedule", None) is not None: + schedule = config.mask_schedule.schedule + args = config.mask_schedule.get("params", {}) + mask_schedule = get_mask_schedule(schedule, **args) + else: + mask_schedule = get_mask_schedule(config.training.get("mask_schedule", "cosine")) + + lr_warmup_steps_for_scheduler = config.lr_scheduler.params.warmup_steps + max_train_steps_for_scheduler = config.training.max_train_steps + + # lr_warmup_steps_for_scheduler = config.lr_scheduler.params.warmup_steps * accelerator.num_processes + # max_train_steps_for_scheduler = config.training.max_train_steps * accelerator.num_processes + + lr_scheduler = get_scheduler( + config.lr_scheduler.scheduler, + optimizer=optimizer, + num_warmup_steps=lr_warmup_steps_for_scheduler, + num_training_steps=max_train_steps_for_scheduler, + min_lr_scale=config.lr_scheduler.params.min_lr_scale + ) + + ################################## + # DATALOADER # + ################################# + logger.info("Creating dataloader and lr_scheduler for T2S task") + train_dataset = SpeechTextDataset(config.dataset.data.name, config.dataset.data.subset, config.dataset.data.split) + eval_dataset = train_dataset + + logger.info(f"Dataset Prepared.") + + sampler = DistributedSampler(train_dataset, num_replicas=accelerator.num_processes, rank=accelerator.process_index, shuffle=True) if accelerator.num_processes > 1 else None + + train_dataloader = DataLoader(train_dataset, batch_size=config.training.batch_size_s2t, shuffle=True, collate_fn=collate_fn, num_workers=config.dataset.params.num_workers) + eval_dataloader = DataLoader(eval_dataset, batch_size=config.training.batch_size_s2t, collate_fn=collate_fn, num_workers=config.dataset.params.num_workers) + + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / config.training.gradient_accumulation_steps) + num_train_epochs = math.ceil(config.training.max_train_steps / num_update_steps_per_epoch) + + ################################## + # MODEL RESUME # + ################################# + global_step = 0 + first_epoch = 0 + + # if config.experiment.resume_from_checkpoint: + # dirs = os.listdir(config.experiment.output_dir) + # logger.info(f"dirs: {dirs}") + # dirs = [d for d in dirs if d.startswith("checkpoint")] + # dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + # path = dirs[-1] if len(dirs) > 0 else None + # logger.info(f"path: {path}") + # if path is not None: + # path = os.path.join(config.experiment.output_dir, path) + # logger.info(f"Resuming from checkpoint: {path}") + # global_step = int(os.path.basename(path).split("-")[1]) + # first_epoch = global_step // num_update_steps_per_epoch + # if os.path.exists(f'{path}/unwrapped_model/pytorch_model.bin'): + # state_dict = torch.load(f'{path}/unwrapped_model/pytorch_model.bin', map_location="cpu") + # model.load_state_dict(state_dict, strict=True) + # del state_dict + # elif os.path.exists(f'{path}/unwrapped_model/pytorch_model.bin.index.json'): + # from transformers.modeling_utils import load_sharded_checkpoint + # load_sharded_checkpoint(model, f'{path}/unwrapped_model/') + # # if safetensors sharded checkpoint exists + # elif os.path.exists(f'{path}/unwrapped_model/model.safetensors.index.json'): + # from transformers.modeling_utils import load_sharded_checkpoint + # load_sharded_checkpoint( + # model, + # f'{path}/unwrapped_model/', + # ) + # else: + # raise FileNotFoundError(f"Checkpoint {path}/unwrapped_model/pytorch_model.bin not found") + if config.experiment.resume_from_checkpoint: + output_dir = Path(config.experiment.output_dir) + if os.path.exists(output_dir): + checkpoint_dirs = [d for d in os.listdir(output_dir) if d.startswith("checkpoint-")] + if checkpoint_dirs: + latest_checkpoint = sorted(checkpoint_dirs, key=lambda x: int(x.split("-")[1]))[-1] + resume_path = os.path.join(output_dir, latest_checkpoint) + + logger.info(f"Resuming from checkpoint: {resume_path}") + + accelerator.load_state(resume_path) + + global_step = int(latest_checkpoint.split("-")[1]) + first_epoch = global_step // num_update_steps_per_epoch + else: + logger.info("No checkpoint found to resume from, starting from scratch.") + else: + logger.info("Output directory does not exist, starting from scratch.") + else: + logger.info("Not resuming from checkpoint.") + + ################################## + # Prepare accelerator # + ################################# + logger.info("Preparing model, optimizer and dataloader") + model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + model, optimizer, train_dataloader, lr_scheduler + ) + + lr_scheduler = get_scheduler( + config.lr_scheduler.scheduler, + optimizer=optimizer, + num_training_steps=config.training.max_train_steps, + num_warmup_steps=config.lr_scheduler.params.warmup_steps, + min_lr_scale=config.lr_scheduler.params.min_lr_scale + ) + + vq_model.to(device=accelerator.device) + + ################################## + # Training # + ################################# + logger.info("***** Running training *****") + logger.info(f" Num train examples = {len(train_dataset)}") + logger.info(f" Num Epochs = {num_train_epochs}") + logger.info(f" Num training steps = {config.training.max_train_steps}") + logger.info(f" Instantaneous batch size per device = {total_batch_size_per_gpu}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {config.training.gradient_accumulation_steps}") + + @torch.no_grad() + def prepare_inputs_and_labels_for_t2s( + input_ids_t2s, prompt_masks, labels_t2s, mask_id=126336, eps=1e-3 + ): + b, l = input_ids_t2s.shape + t = torch.rand(b, device=input_ids_t2s.device) + p_mask = (1 - eps) * t + eps + p_mask = p_mask[:, None].repeat(1, l) + + masked_indices = torch.rand((b, l), device=input_ids_t2s.device) < p_mask + noisy_batch = torch.where(masked_indices, mask_id, input_ids_t2s) + masked_indices = noisy_batch == mask_id + + noisy_batch[prompt_masks.bool()] = input_ids_t2s[prompt_masks.bool()] + masked_indices = noisy_batch == mask_id + + prompt_masks = prompt_masks.to(torch.int64) + answer_lengths = torch.sum((1 - prompt_masks), dim=-1, keepdim=True) + answer_lengths = answer_lengths.repeat(1, noisy_batch.shape[1]) + + return noisy_batch, labels_t2s, p_mask, answer_lengths + + batch_time_m = AverageMeter() + data_time_m = AverageMeter() + end = time.time() + + + for epoch in range(first_epoch, num_train_epochs): + model.train() + for step, batch in enumerate(train_dataloader): + data_time_m.update(time.time() - end) + + batch_size_t2s = len(batch["audio_path"]) + audio_paths, texts_t2s = batch["audio_path"], batch["text"] + offset = len(uni_prompting.text_tokenizer) + config.model.mmada.codebook_size + + all_audio_tokens = [] + for path in audio_paths: + tokens = vq_model.encode(path) + tokens_with_offset = tokens + offset + all_audio_tokens.append(tokens_with_offset) + + pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 50256 + max_len = max(t.shape[1] for t in all_audio_tokens) + + padded_batch = [] + for tokens in all_audio_tokens: + num_padding = max_len - tokens.shape[1] + padding = torch.full((tokens.shape[0], num_padding), pad_token_id, dtype=tokens.dtype, device=tokens.device) + padded_tensor = torch.cat([tokens, padding], dim=1) + padded_batch.append(padded_tensor) + + audio_tokens_t2s = torch.cat(padded_batch, dim=0) + audio_tokens_t2s = audio_tokens_t2s.to(accelerator.device, non_blocking=True) + + input_ids_t2s, prompt_masks, labels_t2s = uni_prompting((texts_t2s, audio_tokens_t2s), 't2s') + + input_ids_t2s, labels, p_mask_t2s, answer_lengths = prepare_inputs_and_labels_for_t2s( + input_ids_t2s, prompt_masks, labels_t2s + ) + + input_ids_t2s = input_ids_t2s.to(accelerator.device, non_blocking=True) + labels = labels.to(accelerator.device, non_blocking=True) + p_mask_t2s = p_mask_t2s.to(accelerator.device, non_blocking=True) + answer_lengths = answer_lengths.to(accelerator.device, non_blocking=True) + + logits, t2s_loss = accelerator.unwrap_model(model).forward_t2s( + input_ids=input_ids_t2s, + labels=labels, + batch_size_t2s=batch_size_t2s, + p_mask_t2s=p_mask_t2s, + answer_lengths=answer_lengths + ) + + # Gather the losses across all processes for logging + avg_loss_t2s = accelerator.gather(t2s_loss.repeat(batch_size_t2s)).mean() + + accelerator.backward(t2s_loss) + + if config.training.max_grad_norm is not None and accelerator.sync_gradients: + accelerator.clip_grad_norm_(model.parameters(), config.training.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + + # Checks if the accelerator has performed an optimization step + if accelerator.sync_gradients: + batch_time_m.update(time.time() - end) + end = time.time() + + avg_masking_rate = accelerator.gather(p_mask_t2s.mean()).mean() + + # Log metrics + if (global_step + 1) % config.experiment.log_every == 0: + samples_per_second_per_gpu = ( + config.training.gradient_accumulation_steps * batch_size_t2s / batch_time_m.val + ) + logs = { + "step_loss_t2s": avg_loss_t2s.item(), + "lr": lr_scheduler.get_last_lr()[0], + "avg_masking_rate": avg_masking_rate.item(), + "samples/sec/gpu": samples_per_second_per_gpu, + "data_time": data_time_m.val, + "batch_time": batch_time_m.val, + } + accelerator.log(logs, step=global_step + 1) + + logger.info( + f"Step: {global_step + 1} " + f"Loss_t2s: {avg_loss_t2s.item():0.4f} " + f"Data (t): {data_time_m.val:0.4f}, {samples_per_second_per_gpu:0.2f}/s/gpu " + f"Batch (t): {batch_time_m.val:0.4f} " + f"LR: {lr_scheduler.get_last_lr()[0]:0.6f}" + ) + batch_time_m.reset() + data_time_m.reset() + + # Save model checkpoint + if (global_step + 1) % config.experiment.save_every == 0: + save_checkpoint(model, config, accelerator, global_step + 1, uni_prompting) + # save_checkpoint( + # accelerator, + # config.experiment.output_dir, + # global_step + 1, + # uni_prompting, + # config.experiment.get("checkpoints_total_limit") + # ) + + # if (global_step + 1) % config.experiment.eval_every == 0 : + # run_evaluation(model, eval_dataloader, accelerator, global_step + 1) + # if accelerator.is_main_process: + # torch.cuda.empty_cache() + + global_step += 1 + + if global_step >= config.training.max_train_steps: + break + if global_step >= config.training.max_train_steps: + break + + accelerator.wait_for_everyone() + save_checkpoint(model, config, accelerator, global_step, uni_prompting) + # save_checkpoint( + # accelerator, + # config.experiment.output_dir, + # global_step + 1, + # uni_prompting, + # config.experiment.get("checkpoints_total_limit") + # ) + if accelerator.is_main_process: + model = accelerator.unwrap_model(model) + model.save_pretrained(config.experiment.output_dir, safe_serialization=True) + + accelerator.end_training() + +def save_checkpoint(model, config, accelerator, global_step, uni_prompting): + output_dir = config.experiment.output_dir + checkpoints_total_limit = config.experiment.get("checkpoints_total_limit", None) + + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if accelerator.is_main_process and checkpoints_total_limit is not None: + checkpoints = os.listdir(output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= checkpoints_total_limit: + num_to_remove = len(checkpoints) - checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = Path(output_dir) / f"checkpoint-{global_step}" + + # retrieve the model on all processes for deepspeed stage 3 to work then save on one process (we are not using stage 3 yet) + # XXX: could also make this conditional on deepspeed + state_dict = accelerator.get_state_dict(model) + if accelerator.is_main_process: + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.save_pretrained( + save_path / "unwrapped_model", + save_function=accelerator.save, + state_dict=state_dict, + safe_serialization=True + ) + json.dump({"global_step": global_step}, (save_path / "metadata.json").open("w+")) + logger.info(f"Saved state to {save_path}") + + # save tokenizer + uni_prompting.text_tokenizer.save_pretrained(save_path/ "unwrapped_model") + +# def save_checkpoint(accelerator, output_dir, global_step, uni_prompting, checkpoints_total_limit=None): + +# if accelerator.is_main_process and checkpoints_total_limit is not None: +# checkpoints = os.listdir(output_dir) +# checkpoints = [d for d in checkpoints if d.startswith("checkpoint-")] + +# if len(checkpoints) >= checkpoints_total_limit: +# checkpoints_to_remove = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))[:len(checkpoints) - checkpoints_total_limit + 1] +# logger.info(f"Removing {len(checkpoints_to_remove)} old checkpoints to keep a total of {checkpoints_total_limit}") +# for ckpt in checkpoints_to_remove: +# shutil.rmtree(os.path.join(output_dir, ckpt)) + +# save_path = os.path.join(output_dir, f"checkpoint-{global_step}") + +# accelerator.save_state(save_path) +# logger.info(f"Saved complete state to {save_path}") + +# if accelerator.is_main_process: +# uni_prompting.text_tokenizer.save_pretrained(save_path) + +def log_grad_norm(model, accelerator, global_step): + if not accelerator.is_main_process: + return + for name, param in model.named_parameters(): + if param.grad is not None: + grads = param.grad.detach().data + grad_norm = (grads.norm(p=2) / grads.numel()).item() + accelerator.log({"grad_norm/" + name: grad_norm}, step=global_step) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/MMaDA/training/train_mmada_t2s_test.py b/MMaDA/training/train_mmada_t2s_test.py new file mode 100644 index 0000000000000000000000000000000000000000..568aa8b57d37ffccb49df414cb5d8b1319bf7ead --- /dev/null +++ b/MMaDA/training/train_mmada_t2s_test.py @@ -0,0 +1,390 @@ +# Copyright 2025 AIDAS Lab +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +warnings.filterwarnings("ignore") + +import os +import sys +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +os.environ["TOKENIZERS_PARALLELISM"] = "true" +import json +import logging +import math +import shutil +import time +from pathlib import Path +from typing import Union, List + +import numpy as np +from PIL import Image +from omegaconf import OmegaConf +import wandb +import torch +from torch.optim import AdamW +from lightning.pytorch.utilities import CombinedLoader + +from transformers import AutoTokenizer, AutoConfig +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import DistributedType, set_seed + +# +++++ S2T-specific Imports +++++ +from models.modeling_emova_speech_tokenizer import EMOVASpeechTokenizer +from datasets import load_dataset +from torch.utils.data import Dataset, DataLoader, DistributedSampler +from tqdm.auto import tqdm +from training.data import SpeechTextDataset +# ++++++++++++++++++++++++++++++ + +from training.utils import get_config, flatten_omega_conf +from models import get_mask_schedule, MMadaModelLM, MMadaConfig +from training.prompting_utils import UniversalPrompting +from models.lr_schedulers import get_scheduler +from models.logging import set_verbosity_info, set_verbosity_error +from training.utils import AverageMeter + +logger = get_logger(__name__, log_level="INFO") + +def resize_vocab(model, config): + logger.info(f"Resizing token embeddings to {config.model.mmada.new_vocab_size}") + model.resize_token_embeddings(config.model.mmada.new_vocab_size) + +def get_vq_model_class(model_type): + if model_type == "emova": + return EMOVASpeechTokenizer.from_pretrained( + "Emova-ollm/emova_speech_tokenizer_hf" + ) + else: + raise ValueError(f"model_type {model_type} not supported.") + +def collate_fn(batch): + return { + 'audio_path': [item['audio_path'] for item in batch], + 'text': [item['text'] for item in batch], + } + +def main(): + ######################### + # SETUP Accelerator # + ######################### + config = get_config() + + # +++++ DEBUG PRINT +++++ + if config.model.mmada.get("train_step"): + logger.info("="*50) + logger.info(f"[DEBUG] Found 'model.mmada.train_step': {config.model.mmada.train_step}") + logger.info("[DEBUG] This value might be overriding 'training.max_train_steps'.") + logger.info("="*50) + # +++++++++++++++++++++++ + + accelerator = Accelerator( + gradient_accumulation_steps=config.training.gradient_accumulation_steps, + mixed_precision=config.training.mixed_precision, + log_with="wandb", + project_dir=str(Path(config.experiment.output_dir) / "logs"), + ) + + total_batch_size_per_gpu = config.training.batch_size_s2t + total_batch_size = ( + config.training.batch_size_s2t + * accelerator.num_processes + * config.training.gradient_accumulation_steps + ) + + if accelerator.distributed_type == DistributedType.DEEPSPEED: + accelerator.state.deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = ( + total_batch_size_per_gpu + ) + + ##################################### + # SETUP LOGGING, SEED and CONFIG # + ##################################### + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + set_verbosity_info() + else: + set_verbosity_error() + + if accelerator.is_main_process: + # ... (wandb setup code remains the same) ... + os.makedirs(config.experiment.output_dir, exist_ok=True) + OmegaConf.save(config, Path(config.experiment.output_dir) / "config.yaml") + + if config.training.seed is not None: + set_seed(config.training.seed) + + ######################### + # MODELS and OPTIMIZER # + ######################### + logger.info("Loading models and optimizer") + + tokenizer = AutoTokenizer.from_pretrained(config.model.mmada.tokenizer_path, padding_side="left") + uni_prompting = UniversalPrompting(tokenizer, max_text_len=config.dataset.preprocessing.max_seq_length, + special_tokens=("<|s2t|>", "<|soa|>", "<|eoa|>", "<|soi|>", "<|eoi|>", "<|sov|>", "<|eov|>", "<|t2i|>", "<|mmu|>", "<|t2v|>", "<|v2v|>", "<|lvg|>", "<|t2s|>"), + ignore_id=-100, cond_dropout_prob=config.training.cond_dropout_prob, use_reserved_token=True) + + vq_model = get_vq_model_class(config.model.vq_model.type) + vq_model.eval() + vq_model.requires_grad_(False) + + model = MMadaModelLM.from_pretrained(config.model.mmada.pretrained_model_path, torch_dtype=torch.bfloat16) + + unwrapped_model = accelerator.unwrap_model(model) + resize_vocab(unwrapped_model, config) + + optimizer_config = config.optimizer.params + # ... (optimizer setup code remains the same) ... + optimizer = AdamW( + [ + { + "params": [p for n, p in model.named_parameters() if p.requires_grad and not any(nd in n for nd in ["bias", "layer_norm.weight", "mlm_ln.weight", "embeddings.weight"])], + "weight_decay": optimizer_config.weight_decay, + }, + { + "params": [p for n, p in model.named_parameters() if p.requires_grad and any(nd in n for nd in ["bias", "layer_norm.weight", "mlm_ln.weight", "embeddings.weight"])], + "weight_decay": 0.0, + }, + ], + lr=optimizer_config.learning_rate, + betas=(optimizer_config.beta1, optimizer_config.beta2), + weight_decay=optimizer_config.weight_decay, + eps=optimizer_config.epsilon, + ) + + ################################## + # DATALOADER # + ################################## + logger.info("Creating dataloader and lr_scheduler for T2S task") + train_dataset = SpeechTextDataset(config.dataset.data.name, config.dataset.data.subset, config.dataset.data.split) + + # +++++ DEBUG PRINT +++++ + logger.info(f"[DEBUG] Loaded dataset '{config.dataset.data.name}/{config.dataset.data.subset}'. Number of samples: {len(train_dataset)}") + # +++++++++++++++++++++++ + + train_dataloader = DataLoader(train_dataset, batch_size=config.training.batch_size_s2t, shuffle=True, collate_fn=collate_fn, num_workers=config.dataset.params.num_workers) + + # +++++ DEBUG PRINT +++++ + logger.info(f"[DEBUG] Dataloader created. Number of batches: {len(train_dataloader)}") + # +++++++++++++++++++++++ + + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / config.training.gradient_accumulation_steps) + num_train_epochs = math.ceil(config.training.max_train_steps / num_update_steps_per_epoch) + + # +++++ DEBUG PRINT +++++ + logger.info("="*50) + logger.info("[DEBUG] CONTROL FLOW CALCULATION:") + logger.info(f"[DEBUG] > config.training.max_train_steps: {config.training.max_train_steps}") + logger.info(f"[DEBUG] > num_update_steps_per_epoch: {num_update_steps_per_epoch}") + logger.info(f"[DEBUG] > Calculated num_train_epochs: {num_train_epochs}") + logger.info("="*50) + # +++++++++++++++++++++++ + + lr_scheduler = get_scheduler( + config.lr_scheduler.scheduler, + optimizer=optimizer, + num_training_steps=config.training.max_train_steps, + num_warmup_steps=config.lr_scheduler.params.warmup_steps, + min_lr_scale=config.lr_scheduler.params.min_lr_scale + ) + + ################################## + # MODEL RESUME # + ################################## + global_step = 0 + first_epoch = 0 + if config.experiment.resume_from_checkpoint: + # ... (resume logic) ... + # Simplified for clarity, assuming your logic is correct + # You should add prints inside your actual resume block + logger.info(f"[DEBUG] Resume logic will attempt to load state.") + # In your resume block, after loading: + # logger.info(f"[DEBUG] Resumed from checkpoint. global_step={global_step}, first_epoch={first_epoch}") + else: + logger.info("[DEBUG] Not resuming from checkpoint. Starting at step 0, epoch 0.") + + ################################## + # Prepare accelerator # + ################################## + logger.info("Preparing model, optimizer and dataloader") + model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + model, optimizer, train_dataloader, lr_scheduler + ) + + # +++++ ADD THIS CRITICAL LOG +++++ + # This will show the real number of batches on each GPU + if accelerator.is_main_process: + logger.info("="*50) + logger.info(f"[DEBUG] TRUE DATALOADER LENGTH (per process) after accelerator.prepare: {len(train_dataloader)}") + logger.info(f"This means the inner loop will run this many times.") + logger.info(f"Expected global steps per epoch: {len(train_dataloader) // config.training.gradient_accumulation_steps}") + logger.info("="*50) + # ++++++++++++++++++++++++++++++++++ + + vq_model.to(device=accelerator.device) + + ################################## + # Training # + ################################## + logger.info("***** Running training *****") + # ... (other logs) ... + + # ... (prepare_inputs_and_labels_for_t2s function) ... + + for epoch in range(first_epoch, num_train_epochs): + # +++++ DEBUG PRINT +++++ + logger.info(f"[DEBUG] >>> Starting Epoch {epoch}/{num_train_epochs - 1} <<<") + # +++++++++++++++++++++++ + model.train() + for step, batch in enumerate(train_dataloader): + # ... (your forward pass logic) ... + + with accelerator.accumulate(model): + # ... (loss calculation, backward, optimizer step, etc.) + accelerator.backward(t2s_loss) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + if accelerator.sync_gradients: + # ... (logging logic) ... + if (global_step + 1) % 1000 == 0: # Log every 1000 steps to avoid spam + logger.info(f"[DEBUG] Loop heartbeat: global_step = {global_step}") + + global_step += 1 + + # Termination check + if global_step >= config.training.max_train_steps: + # +++++ DEBUG PRINT +++++ + logger.info("="*50) + logger.info(f"[DEBUG] BREAKING INNER LOOP: Condition 'global_step >= max_train_steps' is TRUE.") + logger.info(f"[DEBUG] > global_step: {global_step}") + logger.info(f"[DEBUG] > max_train_steps: {config.training.max_train_steps}") + logger.info("="*50) + # +++++++++++++++++++++++ + break + + # +++++ DEBUG PRINT +++++ + logger.info(f"[DEBUG] >>> Finished Epoch {epoch}. Current global_step = {global_step} <<<") + # +++++++++++++++++++++++ + + # After each epoch, check if we should stop the entire training + if global_step >= config.training.max_train_steps: + # +++++ DEBUG PRINT +++++ + logger.info("="*50) + logger.info(f"[DEBUG] BREAKING OUTER LOOP: 'max_train_steps' reached after epoch {epoch}.") + logger.info("="*50) + # +++++++++++++++++++++++ + break + + # +++++ DEBUG PRINT +++++ + logger.info("="*50) + logger.info(f"[DEBUG] Training loop finished. Final global_step = {global_step}.") + logger.info("[DEBUG] Proceeding to final save and shutdown.") + logger.info("="*50) + # +++++++++++++++++++++++ + + accelerator.wait_for_everyone() + save_checkpoint(model, config, accelerator, global_step, uni_prompting) + # save_checkpoint( + # accelerator, + # config.experiment.output_dir, + # global_step + 1, + # uni_prompting, + # config.experiment.get("checkpoints_total_limit") + # ) + if accelerator.is_main_process: + model = accelerator.unwrap_model(model) + model.save_pretrained(config.experiment.output_dir, safe_serialization=True) + + accelerator.end_training() + +def save_checkpoint(model, config, accelerator, global_step, uni_prompting): + output_dir = config.experiment.output_dir + checkpoints_total_limit = config.experiment.get("checkpoints_total_limit", None) + + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if accelerator.is_main_process and checkpoints_total_limit is not None: + checkpoints = os.listdir(output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= checkpoints_total_limit: + num_to_remove = len(checkpoints) - checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = Path(output_dir) / f"checkpoint-{global_step}" + + # retrieve the model on all processes for deepspeed stage 3 to work then save on one process (we are not using stage 3 yet) + # XXX: could also make this conditional on deepspeed + state_dict = accelerator.get_state_dict(model) + if accelerator.is_main_process: + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.save_pretrained( + save_path / "unwrapped_model", + save_function=accelerator.save, + state_dict=state_dict, + safe_serialization=True + ) + json.dump({"global_step": global_step}, (save_path / "metadata.json").open("w+")) + logger.info(f"Saved state to {save_path}") + + # save tokenizer + uni_prompting.text_tokenizer.save_pretrained(save_path/ "unwrapped_model") + +# def save_checkpoint(accelerator, output_dir, global_step, uni_prompting, checkpoints_total_limit=None): + +# if accelerator.is_main_process and checkpoints_total_limit is not None: +# checkpoints = os.listdir(output_dir) +# checkpoints = [d for d in checkpoints if d.startswith("checkpoint-")] + +# if len(checkpoints) >= checkpoints_total_limit: +# checkpoints_to_remove = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))[:len(checkpoints) - checkpoints_total_limit + 1] +# logger.info(f"Removing {len(checkpoints_to_remove)} old checkpoints to keep a total of {checkpoints_total_limit}") +# for ckpt in checkpoints_to_remove: +# shutil.rmtree(os.path.join(output_dir, ckpt)) + +# save_path = os.path.join(output_dir, f"checkpoint-{global_step}") + +# accelerator.save_state(save_path) +# logger.info(f"Saved complete state to {save_path}") + +# if accelerator.is_main_process: +# uni_prompting.text_tokenizer.save_pretrained(save_path) + +def log_grad_norm(model, accelerator, global_step): + if not accelerator.is_main_process: + return + for name, param in model.named_parameters(): + if param.grad is not None: + grads = param.grad.detach().data + grad_norm = (grads.norm(p=2) / grads.numel()).item() + accelerator.log({"grad_norm/" + name: grad_norm}, step=global_step) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/MMaDA/training/train_mmada_v2s.py b/MMaDA/training/train_mmada_v2s.py new file mode 100644 index 0000000000000000000000000000000000000000..5d0085c759dee10aa5d18f6f13f9efdca4e22db3 --- /dev/null +++ b/MMaDA/training/train_mmada_v2s.py @@ -0,0 +1,1067 @@ +# Copyright 2025 AIDAS Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import os.path as osp +import sys +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +os.environ["TOKENIZERS_PARALLELISM"] = "true" +import json +import pandas +import logging +import math +import shutil +import time +import cv2 +import glob +import random +from tqdm import tqdm +from pathlib import Path +from typing import Optional, Union +import csv +import numpy as np +from PIL import Image +from omegaconf import OmegaConf +import wandb +import torch +from torch.optim import AdamW +from lightning.pytorch.utilities import CombinedLoader + +from transformers import AutoTokenizer, AutoConfig +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import DistributedType, set_seed +# +++++ I2I-specific Imports +++++ +from datasets import load_dataset +from torch.utils.data import Dataset, DataLoader +from tqdm.auto import tqdm +# ++++++++++++++++++++++++++++++ + +# +++++ Omni-modal-specific Imports +++++ +from models.modeling_emova_speech_tokenizer import EMOVASpeechTokenizer +from datasets import load_dataset +from torch.utils.data import Dataset, DataLoader, DistributedSampler +from tqdm.auto import tqdm +from training.data import SpeechTextDataset, MixedSpeechTextDataset, load_video_mp4, VideoCaptionDataset, S2T_INSTRUCTION, T2S_INSTRUCTION +# import librosa + +from training.data import Text2ImageDataset +from training.utils import get_config, flatten_omega_conf, image_transform +from training.imagenet_dataset import ImageNetDataset +from parquet import RefinedWebDataset, ChatDataset + +from models import MAGVITv2, get_mask_schedule, OMadaModelLM, OMadaConfig +from training.prompting_utils import UniversalPrompting +from models.lr_schedulers import get_scheduler +from models.logging import set_verbosity_info, set_verbosity_error + +from torch.utils.data import DataLoader, Dataset +from torch.utils.data.distributed import DistributedSampler + +# ++++++++ EVALUATION IMPORTS ++++++++ +import re +import editdistance +import soundfile as sf +from functools import partial +from transformers import pipeline +# ++++++++++++++++++++++++++++++++++++ + +SYSTEM_PROMPT_LEN = 28 + +from training.utils import get_config, flatten_omega_conf, mask_or_random_replace_tokens, AverageMeter + +try: + import apex + + is_apex_available = True +except ImportError: + is_apex_available = False + +logger = get_logger(__name__, log_level="INFO") + +def pad_tensor(tensor, length, value): + pad_size = length - tensor.shape[1] + if pad_size <= 0: + return tensor + # Pad on the right side of the sequence (last dimension) + return torch.nn.functional.pad(tensor, (0, pad_size), "constant", value) + +def pad_answer_lengths(ans: torch.Tensor, length: int) -> torch.Tensor: + b, l = ans.shape + if l >= length: + return ans + pad_block = ans[:, :1].expand(b, length - l) + return torch.cat([ans, pad_block], dim=1) + +V2S_INSTRUCTION = [ + "Generate speech that describes the given video.", + + "Narrate the events happening in the video.", + + "Produce spoken audio describing the video content.", + + "Convert the video into a detailed spoken narration.", + + "Speak a description of what is shown in the video.", + + "Synthesize speech that explains the content of the video." +] + + +def resize_vocab(model, config): + logger.info(f"Resizing token embeddings to {config.model.omada.new_vocab_size}") + model.resize_token_embeddings(config.model.omada.new_vocab_size) + +def get_vq_model_class(model_type): + if model_type == "magvitv2": + return MAGVITv2 + elif model_type == "vq16": + pass + # return VQ_16 + elif model_type == "emova": + return EMOVASpeechTokenizer.from_pretrained( + "Emova-ollm/emova_speech_tokenizer_hf" + ) + else: + raise ValueError(f"model_type {model_type} not supported.") + + +def collate_fn_video_speech(batch): + frame_list = [] + speech_list = [] + for item in batch: + frame_tensor = torch.stack(item['video'], dim=0) # (T, C, H, W) + frame_list.append(frame_tensor) + speech_list.append(item['speech']) + + frames = torch.stack(frame_list, dim=0) # (B, T, C, H, W) + #input_ids = torch.stack(input_ids_list, dim=0) + + return { + "video": frames, # torch tensor (B, T, C, H, W) + "speech": speech_list # speech (B, seq_len) + } + + +class VideoSpeechDataset(Dataset): + def __init__( + self, + transform, + tokenizer, + max_seq_length: int, + + resolution: int = 256, + openvid1m_path = "/home/work/AIDAS/data/video/openvid1m/video", + dataset_name = "openvidspeech", + + sample_method='uniform', + num_frames: int = 8, + vq_model=None, + ): + + self.max_seq_length = max_seq_length + self.transform = transform + self.vq_model = vq_model + self.tokenizer = tokenizer + self.resolution = resolution + self.sample_method = sample_method + self.dataset_name = dataset_name + self.num_frames = num_frames + self.caption_prompt = V2S_INSTRUCTION + self.caption_prompt = ['<|start_header_id|>user<|end_header_id|>\n' + prompt + '<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n' for prompt in self.caption_prompt] + + if dataset_name == 'openvidspeech': + self.audio_root = "/home/work/AIDAS/data/video-speech" + self.vid_root = "/home/work/AIDAS/data/video/openvid1m/video/video" + self.data_list = self.load_openvidspeech() + else: + raise ValueError(f"Invalid video-speech dataset name: {dataset_name}.") + + def load_openvidspeech(self): + audio_csv_path = osp.join(self.audio_root, "openvid-speech.csv") + with open(audio_csv_path, "r") as f: + wav_list = [ + line.strip().removesuffix(".wav") + for line in f + if line.strip() + ] + self.dataset_length = len(wav_list) + + dataset_list = [] + for wav_path in wav_list: + dict_item = {} + dict_item['video'] = osp.join(self.vid_root, wav_path+".mp4") + + dict_item['speech'] = osp.join(self.audio_root,"openvid-speech", wav_path+".wav") + dataset_list.append(dict_item) + + print(f"{self.dataset_length} video-speech files has been loaded.") + return dataset_list + + + def _get_caption_prompt(self): + """ + Get a random caption prompt from the list of caption prompts. + """ + return np.random.choice(self.caption_prompt) + + def _tokenize(self, text): + if self.tokenizer is not None: + input_ids = self.tokenizer( + text, + truncation=True, + max_length=2 * self.max_seq_length, + padding=False, + return_tensors="pt" + )[0] + + if len(input_ids) > self.max_seq_length: + return None + else: + return input_ids + else: + raise ValueError("Tokenizer is not provided.") + + def __len__(self): + return len(self.data_list) + + def __getitem__(self, idx): + max_try_count = 50 + + for try_count in range(max_try_count): + try: + data = self._sample_data(idx) + if data is not None: + return { + "video": data["video"], + "speech": data["speech"] + } + except Exception as e: + print(f"Error loading data: {e}") + idx = random.randint(0, self.dataset_length - 1) + + return None + + + def _sample_data(self, idx): + if self.dataset_name == 'openvidspeech': + data_row = self.data_list[idx] + video_path = data_row["video"] + speech_path = data_row["speech"] + else: + raise ValueError(f"Invalid video-speech dataset name: {self.dataset_name}.") + + frames = load_video_mp4( + video_path=video_path, + sample_method=self.sample_method, + num_frames=self.num_frames, + resolution=self.resolution, + transform=self.transform + ) + + return { + "video": frames, # torch tensor (T, C, H, W) + "speech": speech_path + } + + +def load_video_mp4( + video_path, + sample_method='uniform', + num_frames: int = 8, + resolution: int = 256, + transform=None, + ): + """ + load video frames and then return it as a list of images + + args: + video_path: path to the video file + sample: sampling method, 'uniform' or 'random' + num_frames: number of frames to sample from the video + return: video frames as a list of images + """ + + cap = cv2.VideoCapture(video_path) + if not cap.isOpened(): + raise IOError(f"Could not open video file {video_path}") + + frames = [] + while True: + ret, frame = cap.read() + if not ret: + break + # Convert BGR to RGB + frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + frames.append(Image.fromarray(frame)) + + cap.release() + + total_frames = len(frames) + + if total_frames < num_frames: + raise ValueError(f"Video {video_path} has less than 8 frames, got {total_frames} frames.") + + if sample_method == 'uniform': + indices = np.linspace(0, total_frames - 1, num_frames).astype(int) + elif sample_method == 'random': + indices = np.random.choice(total_frames, num_frames, replace=False) + indices = sorted(indices) + else: + raise ValueError(f"Sampling method {sample_method} not supported.") + + sampled_frames = [] + + for idx in indices: + frame = frames[idx] + if transform: + frame = transform(frame, resolution=resolution) + sampled_frames.append(frame) + + return sampled_frames + +def main(): + ######################### + # SETUP Accelerator # + ######################### + config = get_config() + + # Enable TF32 on Ampere GPUs + if config.training.enable_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.benchmark = True + torch.backends.cudnn.deterministic = False + + config.experiment.logging_dir = str(Path(config.experiment.output_dir) / "logs") + accelerator = Accelerator( + gradient_accumulation_steps=config.training.gradient_accumulation_steps, + mixed_precision=config.training.mixed_precision, + log_with="wandb", + project_dir=config.experiment.logging_dir, + split_batches=True, + ) + + total_batch_size_per_gpu = (config.training.batch_size_v2s) + total_batch_size = ( + (config.training.batch_size_v2s) + * accelerator.num_processes * config.training.gradient_accumulation_steps + ) + + if accelerator.distributed_type == DistributedType.DEEPSPEED: + accelerator.state.deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = ( + total_batch_size_per_gpu + ) + + ##################################### + # SETUP LOGGING, SEED and CONFIG # + ##################################### + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + set_verbosity_info() + else: + set_verbosity_error() + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + resume_wandb_run = config.wandb.resume + run_id = config.wandb.get("run_id", None) + + if run_id is None: + resume_wandb_run = False + run_id = wandb.util.generate_id() + config.wandb.run_id = run_id + + wandb_init_kwargs = dict( + name=config.experiment.name, + id=run_id, + resume=resume_wandb_run, + entity=config.wandb.get("entity", None), + config_exclude_keys=[], + dir = config.experiment.logging_dir, + ) + wandb_config = {k: v for k, v in flatten_omega_conf(config, resolve=True)} + wandb_config.pop("experiment.resume_from_checkpoint") + + accelerator.init_trackers( + config.experiment.project, + config=wandb_config, + init_kwargs={"wandb": wandb_init_kwargs}, + ) + + if accelerator.is_main_process: + os.makedirs(config.experiment.output_dir, exist_ok=True) + config_path = Path(config.experiment.output_dir) / "config.yaml" + logging.info(f"Saving config to {config_path}") + OmegaConf.save(config, config_path) + + # If passed along, set the training seed now. + if config.training.seed is not None: + set_seed(config.training.seed) + + ######################### + # MODELS and OPTIMIZER # + ######################### + + logger.info("Loading models and optimizer") + + tokenizer = AutoTokenizer.from_pretrained(config.model.omada.tokenizer_path, padding_side="left") + + uni_prompting = UniversalPrompting(tokenizer, max_text_len=config.dataset.preprocessing.max_seq_length, max_audio_len=config.dataset.preprocessing.max_aud_length, + special_tokens=( + "<|soi|>", "<|eoi|>", "<|sov|>", "<|eov|>", "<|t2i|>", + "<|mmu|>", "<|t2v|>", "<|v2v|>", "<|lvg|>", + # Omada Special Tokens + "<|v2t|>", "<|s2t|>", "<|t2s|>", "<|soa|>", "<|eoa|>, <|v2s|>", + ), + ignore_id=-100, cond_dropout_prob=config.training.cond_dropout_prob, use_reserved_token=True) + + print('special tokens : \n', uni_prompting.sptids_dict) + + speech_vocab_start = len(uni_prompting.text_tokenizer) + int(config.model.omada.codebook_size) + audio_codebook_size = max(int(config.model.omada.new_vocab_size) - speech_vocab_start, 0) + t2s_special_token_ids = { + "eoa": int(uni_prompting.sptids_dict['<|eoa|>'][0].item()), + "eos": int(uni_prompting.text_tokenizer.eos_token_id), + } + + # VQ model for processing image into discrete tokens + vq_model_image = get_vq_model_class(config.model.vq_model_image.type) + if config.model.vq_model_image.get("pretrained_model_path", None): + vq_model_image = vq_model_image().to(accelerator.device) + state_dict = torch.load(config.model.vq_model_image.pretrained_model_path)['model'] + vq_model_image.load_state_dict(state_dict) + else: + vq_model_image = vq_model_image.from_pretrained(config.model.vq_model_image.vq_model_name).to(accelerator.device) + + + # Load vq model for speech + vq_model_audio = get_vq_model_class(config.model.vq_model_audio.type) + vq_model_audio = vq_model_audio.from_pretrained(config.model.vq_model_audio.vq_model_name).to(accelerator.device) + vq_model_audio.eval() + vq_model_audio.requires_grad_(False) + + + model = OMadaModelLM.from_pretrained(config.model.omada.pretrained_model_path, torch_dtype=torch.bfloat16).to(accelerator.device) + + # Resize Vocab size for Audio Modality + unwrapped_model = accelerator.unwrap_model(model) + original_vocab_size = unwrapped_model.get_input_embeddings().weight.shape[0] + logger.info("="*50) + logger.info(f"Calling resize_vocab...") + logger.info(f"Vocab size BEFORE resizing: {original_vocab_size}") + + resize_vocab(unwrapped_model, config) + + resized_vocab_size = unwrapped_model.get_input_embeddings().weight.shape[0] + logger.info(f"Vocab size AFTER resizing: {resized_vocab_size}") + logger.info(f"Config 'new_vocab_size': {config.model.omada.new_vocab_size}") + + if resized_vocab_size == config.model.omada.new_vocab_size: + logger.info("āœ… Vocab resize successful!") + else: + logger.info("āŒ Vocab resize FAILED or did not match config!") + logger.info("="*50) + mask_id = model.config.mask_token_id + + ################################## + # Optimizer and LR scheduler # + ################################# + optimizer_config = config.optimizer.params + + # no decay on bias and layernorm and embedding + no_decay = ["bias", "layer_norm.weight", "mlm_ln.weight", "embeddings.weight"] + optimizer_grouped_parameters = [ + { + "params": [p for n, p in model.named_parameters() if + p.requires_grad and not any(nd in n for nd in no_decay)], + "weight_decay": optimizer_config.weight_decay, + }, + { + "params": [p for n, p in model.named_parameters() if + p.requires_grad and any(nd in n for nd in no_decay)], + "weight_decay": 0.0, + }, + ] + + optimizer_type = config.optimizer.name + if optimizer_type == "adamw": + optimizer = AdamW( + optimizer_grouped_parameters, + lr=optimizer_config.learning_rate, + betas=(optimizer_config.beta1, optimizer_config.beta2), + weight_decay=optimizer_config.weight_decay, + eps=optimizer_config.epsilon, + ) + else: + raise ValueError(f"Optimizer {optimizer_type} not supported") + + # Create mask scheduler + if config.get("mask_schedule", None) is not None: + schedule = config.mask_schedule.schedule + args = config.mask_schedule.get("params", {}) + mask_schedule = get_mask_schedule(schedule, **args) + else: + mask_schedule = get_mask_schedule(config.training.get("mask_schedule", "cosine")) + + lr_scheduler = get_scheduler( + config.lr_scheduler.scheduler, + optimizer=optimizer, + num_training_steps=config.training.max_train_steps, + num_warmup_steps=config.lr_scheduler.params.warmup_steps, + min_lr_scale=config.lr_scheduler.params.min_lr_scale + ) + + ################################## + # DATALOADER # + ################################# + logger.info("Creating dataloaders and lr_scheduler") + + # DataLoaders creation: + # We use webdataset for data loading. The dataloaders are created with sampling with replacement. + # We don't do dataset resuming here, instead we resample the shards and buffer each time. The sampling is stochastic. + # This means that the dataloading is not deterministic, but it's fast and efficient. + preproc_config = config.dataset.preprocessing + dataset_config = config.dataset.params + + video_speech_dataset = VideoSpeechDataset( + transform=image_transform, + tokenizer=uni_prompting.text_tokenizer, + max_seq_length=preproc_config.max_seq_length, + resolution=preproc_config.resolution, + openvid1m_path = "/home/work/AIDAS/data/video/openvid1m/video", + dataset_name = "openvidspeech", + sample_method="uniform", + num_frames=8, + ) + + video_speech_dataloader = DataLoader( + video_speech_dataset, + batch_size= config.training.batch_size_v2s, + num_workers = dataset_config.num_workers, + collate_fn = collate_fn_video_speech + ) + + num_update_steps_per_epoch = math.ceil( + config.training.batch_size_v2s / config.training.gradient_accumulation_steps + ) + num_train_epochs = math.ceil(config.training.max_train_steps / num_update_steps_per_epoch) + + + ################################## + # MODEL RESUME # + ################################# + global_step = 0 + first_epoch = 0 + start_step = 0 + + if config.experiment.resume_from_checkpoint: + dirs = os.listdir(config.experiment.output_dir) + logger.info(f"dirs: {dirs}") + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + logger.info(f"path: {path}") + if path is not None: + path = os.path.join(config.experiment.output_dir, path) + logger.info(f"Resuming from checkpoint: {path}") + global_step = start_step = int(os.path.basename(path).split("-")[1]) + first_epoch = global_step // num_update_steps_per_epoch + if os.path.exists(f'{path}/unwrapped_model/pytorch_model.bin'): + state_dict = torch.load(f'{path}/unwrapped_model/pytorch_model.bin', map_location="cpu") + model.load_state_dict(state_dict, strict=True) + del state_dict + elif os.path.exists(f'{path}/unwrapped_model/pytorch_model.bin.index.json'): + from safetensors.torch import load_file + from transformers.modeling_utils import load_sharded_checkpoint + load_sharded_checkpoint(model, f'{path}/unwrapped_model/') + # if safetensors sharded checkpoint exists + elif os.path.exists(f'{path}/unwrapped_model/model.safetensors.index.json'): + from transformers.modeling_utils import load_sharded_checkpoint + load_sharded_checkpoint( + model, + f'{path}/unwrapped_model/', + # weight_map=None, + # load_state_dict_fn="safetensors" + ) + else: + raise FileNotFoundError(f"Checkpoint {path}/unwrapped_model/pytorch_model.bin not found") + else: + logger.info("Not resuming from checkpoint") + + ################################## + # Prepare accelerator # + ################################# + logger.info("Preparing model, optimizer and dataloaders") + model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler) + + vq_model_image.to(device=accelerator.device) + + mask_dtype = model.get_input_embeddings().weight.dtype + + + def _log_and_flag_failure(message: str, exc: Exception = None): + """Log preprocessing failures on both logger and accelerator console.""" + if exc is not None: + logger.exception(message) + else: + logger.error(message) + accelerator.print(message) + + def safe_audio_encode(audio_path: str, flow_name: str): + try: + tokens = vq_model_audio.encode(audio_path) + return tokens, None + except Exception as exc: + msg = ( + f"[Rank {accelerator.process_index}] {flow_name} audio encode failed " + f"for '{audio_path}': {exc}" + ) + _log_and_flag_failure(msg, exc) + return None, msg + + + ################################## + # Training # + ################################# + logger.info("***** Running training *****") + logger.info(f" Num training steps = {config.training.max_train_steps}") + logger.info(f" Instantaneous batch size per device = {total_batch_size_per_gpu}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {config.training.gradient_accumulation_steps}") + + @torch.no_grad() + def prepare_inputs_and_labels( + pixel_values_or_image_ids: Union[torch.FloatTensor, torch.LongTensor], + texts: Union[str, list[str]], + min_masking_rate: float = 0.0, + is_train: bool = True, + seed: int = None + ): + + image_tokens = vq_model_image.get_code(pixel_values_or_image_ids) + image_tokens = image_tokens + len(uni_prompting.text_tokenizer) + # create MLM mask and labels + input_ids, labels, loss_weight, mask_prob = mask_or_random_replace_tokens( + image_tokens, + mask_id, + config, + mask_schedule=mask_schedule, + is_train=is_train, + seed=seed + ) + input_ids, masks, labels = uni_prompting((texts, input_ids, labels), 't2i') + return input_ids, labels, mask_prob, image_tokens, masks + + @torch.no_grad() + def prepare_inputs_and_labels_for_text( + texts: Union[str, list[str]], max_seq_len, eps=1e-3 + ): + # create MLM mask and labels + + input_ids_lm, prompt_mask, labels_lm = uni_prompting((texts, max_seq_len), 'lm') + b, l = input_ids_lm.shape + t = torch.rand(b, device=input_ids_lm.device) + p_mask = (1 - eps) * t + eps + p_mask = p_mask[:, None].repeat(1, l) + + masked_indices = torch.rand((b, l), device=input_ids_lm.device) < p_mask + # 126336 is used for [MASK] token + noisy_batch = torch.where(masked_indices, mask_id, input_ids_lm) + masked_indices = noisy_batch == mask_id + + return noisy_batch, labels_lm, p_mask + + @torch.no_grad() + def prepare_inputs_and_labels_for_chat_text( + texts: Union[str, list[str]], max_seq_len, eps=1e-3 + ): + # create MLM mask and labels + + input_ids_lm, prompt_mask, labels_lm = uni_prompting((texts, max_seq_len), 'lm_chat') + b, l = input_ids_lm.shape + t = torch.rand(b, device=input_ids_lm.device) + p_mask = (1 - eps) * t + eps + p_mask = p_mask[:, None].repeat(1, l) + + masked_indices = torch.rand((b, l), device=input_ids_lm.device) < p_mask + # 126336 is used for [MASK] token + noisy_batch = torch.where(masked_indices, mask_id, input_ids_lm) + masked_indices = noisy_batch == mask_id + noisy_batch[prompt_mask.bool()] = input_ids_lm[prompt_mask.bool()] + masked_indices = noisy_batch == mask_id + answer_lengths_lm = torch.sum((1 - prompt_mask), dim=-1, keepdim=True) + answer_lengths_lm = answer_lengths_lm.repeat(1, noisy_batch.shape[1]) + + return noisy_batch, labels_lm, p_mask, answer_lengths_lm + + @torch.no_grad() + def prepare_inputs_and_labels_for_v2s( + input_ids_v2s: torch.Tensor, # [B, L], long + prompt_masks: torch.Tensor, # [B, L], 1=prompt, 0=answer + labels_v2s: torch.Tensor, # [B, L], long + eps: float = 1e-3 + ): + device = input_ids_v2s.device + b, l = input_ids_v2s.shape + + # per-sample masking prob in [eps, 1) + p_mask = eps + (1.0 - eps) * torch.rand(b, device=device) # [B] + p_mask = p_mask.unsqueeze(1).expand(b, l) # [B, L] + + # mask only the answer region (prompt_masks==0) + rand_mat = torch.rand((b, l), device=device) + answer_region = (prompt_masks == 0) + masked_indices = (rand_mat < p_mask) & answer_region + + noisy_batch = input_ids_v2s.clone() + noisy_batch[masked_indices] = mask_id + + # answer lengths (broadcasted to [B, L] as your model expects) + answer_lengths = (prompt_masks == 0).sum(dim=-1, keepdim=True) # [B, 1] + answer_lengths = answer_lengths.expand(b, l) # [B, L] + + return noisy_batch.long(), labels_v2s.long(), p_mask, answer_lengths.long() + + + + batch_time_m = AverageMeter() + data_time_m = AverageMeter() + end = time.time() + video_caption_table = wandb.Table(columns=["step", "video_filename", "caption"]) + + for epoch in tqdm(range(first_epoch, num_train_epochs), desc="Epochs", disable=not accelerator.is_main_process, position=0): + model.train() + + batch_bar = tqdm( + video_speech_dataloader, + desc=f"Epoch {epoch + 1}/{num_train_epochs}", + position=1, + leave=False, + disable=not accelerator.is_main_process, + ) + + for batch in batch_bar: + # for loss calculation + batch_size_vid = batch["video"].shape[0] + + # *-------*-------*-------*-------*-------*-------*-------*-------*-------*-------*-------* + # Build formatted sequences for captioning/multimodal understanding + # *-------*-------*-------*-------*-------*-------*-------*-------*-------*-------*-------* + if "llava" in config.dataset.und_type: + raise NotImplementedError + + else: + video_tensor, speech_paths = batch["video"], batch["speech"] # video is given as tensor and speech is given as paths + offset = speech_vocab_start + + failure_messages = [] + step_failed = False + + # encode video first + video_tensor = video_tensor.to(accelerator.device, non_blocking=True) + video_token_list = [] + for video in video_tensor: # each video is (T, C, H, W) + video_token = vq_model_image.get_code(video) # (T, D) + # each video is tokenized into (T, D) + video_token = video_token + len(uni_prompting.text_tokenizer) # add offset for video tokens + video_token = video_token.view(-1) # flatten to (T*D) + video_token_list.append(video_token) + + video_tokens = torch.stack(video_token_list, dim=0) # (B, T*D) + + # encode speech and generate inputs + all_audio_tokens = [] + for path in speech_paths: + tokens, err = safe_audio_encode(path, "v2s") + if err is not None: + failure_messages.append(err) + step_failed = True + break + tokens = tokens.to(accelerator.device, non_blocking=True) + tokens_with_offset = tokens + offset + all_audio_tokens.append(tokens_with_offset) + + if step_failed: + print("Step has been failed!!!!!!!!") + continue + + # generate text prompt + # since there's no text input, only text is the prompt + prompt = random.choice(V2S_INSTRUCTION) + text_prompt = [ + f"<|start_header_id|>user<|end_header_id|>\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n" + for _ in video_tokens + ] + + # get input ids, prompt masks, and labels + print(f"shape of each input: {video_tokens.shape}, {len(text_prompt)},{all_audio_tokens[0]}") + input_ids_v2s, prompt_masks_v2s, labels_v2s = uni_prompting(( + video_tokens, + text_prompt, + all_audio_tokens + ), + 'v2s' + ) + input_ids_v2s, labels_v2s, p_mask_v2s, answer_lengths_v2s = prepare_inputs_and_labels_for_v2s( + input_ids_v2s, prompt_masks_v2s, labels_v2s + ) + + input_ids_v2s = input_ids_v2s.to(accelerator.device, non_blocking=True) + + input_ids = input_ids_v2s + labels = labels_v2s + + if global_step == 0 and epoch == 0: + logger.info("Input ids: {}".format(input_ids.tolist())) + logger.info("Input ids Shape: {}".format(input_ids.shape)) + logger.info("Labels: {}".format(labels.tolist())) + + + logits, loss_v2s = model.forward_v2s( + input_ids=input_ids, + labels=labels, + batch_size_v2s=batch_size_vid, + max_seq_length=config.dataset.preprocessing.max_seq_length, + p_mask_v2s=p_mask_v2s, + answer_lengths=answer_lengths_v2s, + t2s_vocab_start=speech_vocab_start, + t2s_codebook_size=audio_codebook_size, + t2s_special_token_ids=t2s_special_token_ids, + ) + # Gather the losses across all processes for logging (if we use distributed training). + avg_loss_v2s = accelerator.gather(loss_v2s.repeat(config.training.batch_size_v2s)).mean() + loss = loss_v2s + + accelerator.backward(loss) + + if config.training.max_grad_norm is not None and accelerator.sync_gradients: + accelerator.clip_grad_norm_(model.parameters(), config.training.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + + # log gradient norm before zeroing it + if ( + accelerator.sync_gradients + and (global_step + 1) % config.experiment.log_grad_norm_every == 0 + and accelerator.is_main_process + ): + log_grad_norm(model, accelerator, global_step + 1) + + optimizer.zero_grad(set_to_none=True) + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + + batch_time_m.update(time.time() - end) + end = time.time() + + # Log metrics + if (global_step + 1) % config.experiment.log_every == 0: + samples_per_second_per_gpu = ( + config.training.gradient_accumulation_steps * total_batch_size_per_gpu / batch_time_m.val + ) + logs = { + "step_loss_v2s": avg_loss_v2s.item(), + "lr": lr_scheduler.get_last_lr()[0], + # "avg_masking_rate": avg_masking_rate.item(), + "samples/sec/gpu": samples_per_second_per_gpu, + "data_time": data_time_m.val, + "batch_time": batch_time_m.val, + } + accelerator.log(logs, step=global_step + 1) + + logger.info( + f"Step: {global_step + 1} " + f"Loss_v2s: {avg_loss_v2s.item():0.4f} " + f"Data (t): {data_time_m.val:0.4f}, {samples_per_second_per_gpu:0.2f}/s/gpu " + f"Batch (t): {batch_time_m.val:0.4f} " + f"LR: {lr_scheduler.get_last_lr()[0]:0.6f}" + ) + + # resetting batch / data time meters per log window + batch_time_m.reset() + data_time_m.reset() + # Save model checkpoint + if (global_step + 1) % config.experiment.save_every == 0: + save_checkpoint(model, config, accelerator, global_step + 1, uni_prompting) + + + + global_step += 1 + # Stop training if max steps is reached + if global_step >= config.training.max_train_steps: + break + # End for + + if accelerator.is_main_process: + understanding_video( + model, + vq_model, + uni_prompting, + accelerator, + config, + global_step + 1, + wandb_table=video_caption_table + ) + + accelerator.wait_for_everyone() + + # Evaluate and save checkpoint at the end of training + save_checkpoint(model, config, accelerator, global_step, uni_prompting) + + # Save the final trained checkpoint + if accelerator.is_main_process: + model = accelerator.unwrap_model(model) + model.save_pretrained(config.experiment.output_dir, safe_serialization=True) + + accelerator.end_training() + + +def save_checkpoint(model, config, accelerator, global_step, uni_prompting): + output_dir = config.experiment.output_dir + checkpoints_total_limit = config.experiment.get("checkpoints_total_limit", None) + + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if accelerator.is_main_process and checkpoints_total_limit is not None: + checkpoints = os.listdir(output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= checkpoints_total_limit: + num_to_remove = len(checkpoints) - checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = Path(output_dir) / f"checkpoint-{global_step}" + + # retrieve the model on all processes for deepspeed stage 3 to work then save on one process (we are not using stage 3 yet) + # XXX: could also make this conditional on deepspeed + state_dict = accelerator.get_state_dict(model) + if accelerator.is_main_process: + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.save_pretrained( + save_path / "unwrapped_model", + save_function=accelerator.save, + state_dict=state_dict, + safe_serialization=True + ) + json.dump({"global_step": global_step}, (save_path / "metadata.json").open("w+")) + logger.info(f"Saved state to {save_path}") + + # save tokenizer + uni_prompting.text_tokenizer.save_pretrained(save_path/ "unwrapped_model") + + +def log_grad_norm(model, accelerator, global_step): + for name, param in model.named_parameters(): + if param.grad is not None: + grads = param.grad.detach().data + grad_norm = (grads.norm(p=2) / grads.numel()).item() + accelerator.log({"grad_norm/" + name: grad_norm}, step=global_step) + + +@torch.no_grad() +def understanding_video( + model, + vq_model, + uni_prompting, + accelerator, + config, + global_step, + wandb_table=None, +): + logger.info("Understanding videos..") + model.eval() + + video_image_root = "/home/work/AIDAS/video/demo" + + file_list = os.listdir(video_image_root) + file_list = [f for f in file_list if f.lower().endswith(('.mp4'))] + responses = ['' for i in range(len(file_list))] + videos = [] + + device = accelerator.device + + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + else: + weight_dtype = torch.float32 + + for i, file_name in enumerate(file_list): + + batch_size = 1 + + video_path = os.path.join(video_image_root, file_name) + video_frames = load_video_mp4( + video_path=video_path, + num_frames=8, + sample_method="uniform", + resolution=config.dataset.preprocessing.resolution, + transform=image_transform, + ) + + frames = torch.stack(video_frames, dim=0).to(device) # (T, C, H, W) + video_tokens = vq_model_image.get_code(frames) + len(uni_prompting.text_tokenizer) + video_tokens = video_tokens.view(batch_size, -1) # (T * D) + + input_ids = uni_prompting.text_tokenizer(['<|start_header_id|>user<|end_header_id|>\n' + "Please describe this video in detail." +'<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n'])['input_ids'] + input_ids = torch.tensor(input_ids).to(device) + + input_ids = torch.cat([ + (torch.ones(input_ids.shape[0], 1) * uni_prompting.sptids_dict['<|v2s|>']).to(device), + (torch.ones(input_ids.shape[0], 1) * uni_prompting.sptids_dict['<|soi|>']).to(device), + video_tokens, + (torch.ones(input_ids.shape[0], 1) * uni_prompting.sptids_dict['<|eoi|>']).to(device), + (torch.ones(input_ids.shape[0], 1) * uni_prompting.sptids_dict['<|sot|>']).to(device), + input_ids + ], dim=1).long() + with torch.autocast("cuda", dtype=weight_dtype, enabled=accelerator.mixed_precision != "no"): + output_ids = accelerator.unwrap_model(model).mmu_generate(input_ids, max_new_tokens=config.dataset.preprocessing.max_seq_length, steps=config.dataset.preprocessing.max_seq_length // 2, block_length=config.dataset.preprocessing.max_seq_length // 4) + # output_ids = torch.stack(output_ids).squeeze()[None] + + text = uni_prompting.text_tokenizer.batch_decode(output_ids[:, input_ids.shape[1]:], skip_special_tokens=True) + responses[i] += text[0] + + model.train() + + # Log result + + for file_name, caption in zip(file_list, responses): + wandb_table.add_data(global_step, file_name, caption) + logger.info(f"step: {global_step}, video: {file_name}, caption: {caption}") + + accelerator.log({"video_captioning":wandb_table}, step=global_step) + # wandb.log({"video_captioning": wandb_table}, step=global_step) + # print(f"[step {global_step}] wandb_table rows = {len(wandb_table.data)}") + +if __name__ == "__main__": + main() diff --git a/MMaDA/training/train_mmada_v2t.py b/MMaDA/training/train_mmada_v2t.py new file mode 100644 index 0000000000000000000000000000000000000000..ec9d5a5ba73a0620f35ecfd58ac86f01bacf827f --- /dev/null +++ b/MMaDA/training/train_mmada_v2t.py @@ -0,0 +1,998 @@ +# Copyright 2025 MMaDA Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import os.path as osp +import traceback +import pandas as pd +import sys +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +os.environ["TOKENIZERS_PARALLELISM"] = "true" +import json +import pandas +import logging +import math +import shutil +import time +import cv2 +import glob +import random +from tqdm import tqdm +from pathlib import Path +from typing import Union +import csv +import numpy as np +from PIL import Image +from omegaconf import OmegaConf +import wandb +import torch +from torch.optim import AdamW +from lightning.pytorch.utilities import CombinedLoader + +from transformers import AutoTokenizer, AutoConfig +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import DistributedType, set_seed + +from training.data import Text2ImageDataset +from training.utils import get_config, flatten_omega_conf, image_transform +from training.imagenet_dataset import ImageNetDataset +from parquet import RefinedWebDataset, ChatDataset + +from models import MAGVITv2, get_mask_schedule, MMadaModelLM, MMadaConfig +from training.prompting_utils import UniversalPrompting + +from models.lr_schedulers import get_scheduler +from models.logging import set_verbosity_info, set_verbosity_error + +from torch.utils.data import DataLoader, Dataset +from torch.utils.data.distributed import DistributedSampler + +from training.utils import get_config, flatten_omega_conf, mask_or_random_replace_tokens, AverageMeter + +try: + import apex + + is_apex_available = True +except ImportError: + is_apex_available = False + +logger = get_logger(__name__, log_level="INFO") + +V2T_INSTRUCTION = ["Describe the video in detail.", + + "Please provide a detailed description of the video.", + + "What is happening in the video?", + + "Describe the content of the video in detail.",] + + +def get_vq_model_class(model_type): + if model_type == "magvitv2": + return MAGVITv2 + elif model_type == "vq16": + pass + # return VQ_16 + else: + raise ValueError(f"model_type {model_type} not supported.") + +def collate_fn_video_caption(batch): + frame_list = [] + input_ids_list = [] + for item in batch: + frame_tensor = torch.stack(item['video'], dim=0) # (T, C, H, W) + frame_list.append(frame_tensor) + input_ids_list.append(item['caption']) + + frames = torch.stack(frame_list, dim=0) # (B, T, C, H, W) + #input_ids = torch.stack(input_ids_list, dim=0) + + return { + "video": frames, # torch tensor (B, T, C, H, W) + "captions": input_ids_list # input_ids (B, seq_len) + } + +class VideoCaptionDataset(Dataset): + def __init__( + self, + transform, + tokenizer, + max_seq_length: int, + + resolution: int = 256, + panda70m_path = "/home/work/AIDAS/data/video/panda70m/panda70m_training_2m", + openvid1m_path = "/home/work/AIDAS/data/video/openvid1m/video", + webvid10m_path = "/home/work/AIDAS/data/video/webvid10m", + + dataset_name = "openvid1m", + + sample_method='uniform', + num_frames: int = 8, + vq_model=None, + ): + + available_datasets = ['panda70m', 'openvid1m', 'webvid10m'] + if dataset_name not in available_datasets: + raise ValueError(f"Invalid dataset name: {dataset_name}. Available datasets: {available_datasets}") + + self.max_seq_length = max_seq_length + self.transform = transform + self.vq_model = vq_model + self.tokenizer = tokenizer + self.resolution = resolution + self.sample_method = sample_method + self.dataset_name = dataset_name + self.num_frames = num_frames + self.caption_prompt = V2T_INSTRUCTION + self.caption_prompt = ['<|start_header_id|>user<|end_header_id|>\n' + prompt + '<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n' for prompt in self.caption_prompt] + + self.webvid10m_path = webvid10m_path + self.dataset_length = 0 + + if dataset_name == 'panda70m': + self.vid_data = self._collect_panda70m(panda70m_path) + self.dataset_root = panda70m_path + elif dataset_name == 'webvid10m': + self.vid_data = self._collect_webvid10m(webvid10m_path) + self.dataset_root = webvid10m_path + elif dataset_name == 'openvid1m': + self.vid_data = self._collect_openvid1m(openvid1m_path) + self.dataset_root = openvid1m_path + else: + raise ValueError(f"Invalid dataset name: {dataset_name}. Available datasets: panda70m, webvid10m") + + def _get_caption_prompt(self): + """ + Get a random caption prompt from the list of caption prompts. + """ + return np.random.choice(self.caption_prompt) + + def _tokenize(self, text): + if self.tokenizer is not None: + input_ids = self.tokenizer( + text, + truncation=True, + max_length=2 * self.max_seq_length, + padding=False, + return_tensors="pt" + )[0] + + if len(input_ids) > self.max_seq_length: + return None + else: + return input_ids + else: + raise ValueError("Tokenizer is not provided.") + + def _collect_webvid10m(self, root_path): + + print("Loading videos from WebVid10m dataset...") + csv_path = osp.join(root_path, "webvid-10M-train.csv") + + webvid_pd = pd.read_csv(csv_path) + self.dataset_length = len(webvid_pd) + print(f"{len(webvid_pd)} videos has been loaded.") + + return webvid_pd + + def _collect_panda70m(self, root_path): + video_caption_pairs = [] + subdirs = sorted(os.listdir(root_path)) + + print("Loading videos from panda70m dataset...") + for subdir in subdirs: + full_subdir = os.path.join(root_path, subdir) + if not os.path.isdir(full_subdir): + continue + + video_paths = glob(os.path.join(full_subdir, "*.mp4")) + for video_path in video_paths: + caption_path = video_path.replace(".mp4", ".txt") + if os.path.exists(caption_path): + with open(caption_path, 'r') as f: + caption = f.read().strip() + prompt = self._get_caption_prompt() + video_caption_pairs.append({ + "video": video_path, + "caption": prompt + caption + }) + print(f"{len(video_caption_pairs)} videos has been loaded.") + + return video_caption_pairs + + def _collect_openvid1m(self, root_path): + csv_path = osp.join(root_path, "OpenVid-1M.csv") + openvid_pd = pd.read_csv(csv_path) + self.dataset_length = len(openvid_pd) + print(f"{len(openvid_pd)} videos has been loaded.") + + return openvid_pd + + def __len__(self): + return self.dataset_length + + def __getitem__(self, idx): + max_try_count = 5 + + for try_count in range(max_try_count): + try: + data = self._sample_data(idx) + if data is not None: + break + except Exception as e: + print(f"Error loading data: {e}") + print(traceback.format_exc()) + idx = random.randint(0, self.dataset_length - 1) + if try_count == max_try_count - 1: + raise e + + return { + "video": data["video"], # torch tensor (T, C, H, W) + "caption": data["caption"] # input_ids (seq_len) + } + + def _sample_data_webvid10m(self): + store_path = osp.join(self.webvid10m_path, "video_store") + + row = self.video_caption_pairs['webvid10m'].sample(1).iloc[0] + video_id = str(row["videoid"]) + url = row["contentUrl"] + caption = row["name"] + + video_path = osp.join(store_path, f"{video_id}.mp4") + if not osp.exists(video_path): # not downloaded yet + download_video_url(url, video_path) + + # print(video_id) + # print(_whoami_str()) + + return video_path, caption + + def _sample_data(self, idx): + if self.dataset_name == 'webvid10m': + # currently randomly sample from the dataset + video_path, caption = self._sample_data_webvid10m() + elif self.dataset_name == 'panda70m': + raise NotImplementedError("Panda70m is not implemented yet.") + # video_path, caption = self._sample_data_panda70m() + elif self.dataset_name == 'openvid1m': + data_row = self.vid_data.iloc[idx] + video_path = osp.join(self.dataset_root, "video", data_row["video"]) + caption = data_row["caption"] + print(f"{idx}, {data_row['video']}, {caption}") + else: + raise ValueError(f"Invalid dataset name: {self.dataset_name}. Available datasets: panda70m, webvid10m, openvid1m") + + frames = load_video_mp4( + video_path=video_path, + sample_method=self.sample_method, + num_frames=self.num_frames, + resolution=self.resolution, + transform=self.transform + ) + + return { + "video": frames, # torch tensor (T, C, H, W) + "caption": caption # input_ids (seq_len) + } + +def download_video_url(url: str, save_path, timeout=10, max_retries=3) -> bool: + for attempt in range(1, max_retries + 1): + try: + with requests.get(url, stream=True, timeout=timeout) as r: + r.raise_for_status() + with open(save_path, 'wb') as f: + for chunk in r.iter_content(chunk_size=8192): + if chunk: + f.write(chunk) + return True # Success + + except Exception as e: + print(f"[Attempt {attempt}/{max_retries}] Download failed: {e}") + if attempt < max_retries: + sleep_time = 2 ** (attempt - 1) # exponential backoff: 1,2,4,8,... + time.sleep(sleep_time) + else: + return False # all attempts failed + + return False + +def load_video_mp4( + video_path, + sample_method='uniform', + num_frames: int = 8, + resolution: int = 256, + transform=None, + ): + """ + load video frames and then return it as a list of images + + args: + video_path: path to the video file + sample: sampling method, 'uniform' or 'random' + num_frames: number of frames to sample from the video + return: video frames as a list of images + """ + + cap = cv2.VideoCapture(video_path) + if not cap.isOpened(): + raise IOError(f"Could not open video file {video_path}") + + frames = [] + while True: + ret, frame = cap.read() + if not ret: + break + # Convert BGR to RGB + frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + frames.append(Image.fromarray(frame)) + + cap.release() + + total_frames = len(frames) + + if total_frames < num_frames: + raise ValueError(f"Video {video_path} has less than 8 frames, got {total_frames} frames.") + + if sample_method == 'uniform': + indices = np.linspace(0, total_frames - 1, num_frames).astype(int) + elif sample_method == 'random': + indices = np.random.choice(total_frames, num_frames, replace=False) + indices = sorted(indices) + else: + raise ValueError(f"Sampling method {sample_method} not supported.") + + sampled_frames = [] + + for idx in indices: + frame = frames[idx] + if transform: + frame = transform(frame, resolution=resolution) + sampled_frames.append(frame) + + return sampled_frames + +def main(): + ######################### + # SETUP Accelerator # + ######################### + config = get_config() + + # Enable TF32 on Ampere GPUs + if config.training.enable_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.benchmark = True + torch.backends.cudnn.deterministic = False + + config.experiment.logging_dir = str(Path(config.experiment.output_dir) / "logs") + accelerator = Accelerator( + gradient_accumulation_steps=config.training.gradient_accumulation_steps, + mixed_precision=config.training.mixed_precision, + log_with="wandb", + project_dir=config.experiment.logging_dir, + split_batches=True, + ) + + total_batch_size_per_gpu = (config.training.batch_size_v2t) + total_batch_size = ( + (config.training.batch_size_v2t) + * accelerator.num_processes * config.training.gradient_accumulation_steps + ) + + if accelerator.distributed_type == DistributedType.DEEPSPEED: + accelerator.state.deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = ( + total_batch_size_per_gpu + ) + + ##################################### + # SETUP LOGGING, SEED and CONFIG # + ##################################### + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + set_verbosity_info() + else: + set_verbosity_error() + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + resume_wandb_run = config.wandb.resume + run_id = config.wandb.get("run_id", None) + + if run_id is None: + resume_wandb_run = False + run_id = wandb.util.generate_id() + config.wandb.run_id = run_id + + wandb_init_kwargs = dict( + name=config.experiment.name, + id=run_id, + resume=resume_wandb_run, + entity=config.wandb.get("entity", None), + config_exclude_keys=[], + dir = config.experiment.logging_dir, + ) + wandb_config = {k: v for k, v in flatten_omega_conf(config, resolve=True)} + wandb_config.pop("experiment.resume_from_checkpoint") + + accelerator.init_trackers( + config.experiment.project, + config=wandb_config, + init_kwargs={"wandb": wandb_init_kwargs}, + ) + + if accelerator.is_main_process: + os.makedirs(config.experiment.output_dir, exist_ok=True) + config_path = Path(config.experiment.output_dir) / "config.yaml" + logging.info(f"Saving config to {config_path}") + OmegaConf.save(config, config_path) + + # If passed along, set the training seed now. + if config.training.seed is not None: + set_seed(config.training.seed) + + ######################### + # MODELS and OPTIMIZER # + ######################### + logger.info("Loading models and optimizer") + + tokenizer = AutoTokenizer.from_pretrained(config.model.mmada.tokenizer_path, padding_side="left") + + uni_prompting = UniversalPrompting(tokenizer, max_text_len=config.dataset.preprocessing.max_seq_length, + special_tokens=( + "<|soi|>", "<|eoi|>", "<|sov|>", "<|eov|>", "<|t2i|>", + "<|mmu|>", "<|t2v|>", "<|v2v|>", "<|lvg|>", "<|v2t|>" + ), + ignore_id=-100, cond_dropout_prob=config.training.cond_dropout_prob, use_reserved_token=True) + + print('special tokens : \n', uni_prompting.sptids_dict) + + # VQ model for processing image into discrete tokens + vq_model = get_vq_model_class(config.model.vq_model.type) + if config.model.vq_model.get("pretrained_model_path", None): + vq_model = vq_model().to(accelerator.device) + state_dict = torch.load(config.model.vq_model.pretrained_model_path)['model'] + vq_model.load_state_dict(state_dict) + else: + vq_model = vq_model.from_pretrained(config.model.vq_model.vq_model_name).to(accelerator.device) + vq_model.eval() + vq_model.requires_grad_(False) + + model = MMadaModelLM.from_pretrained(config.model.mmada.pretrained_model_path, torch_dtype=torch.bfloat16).to(accelerator.device) + + mask_id = model.config.mask_token_id + + ################################## + # Optimizer and LR scheduler # + ################################# + optimizer_config = config.optimizer.params + + # no decay on bias and layernorm and embedding + no_decay = ["bias", "layer_norm.weight", "mlm_ln.weight", "embeddings.weight"] + optimizer_grouped_parameters = [ + { + "params": [p for n, p in model.named_parameters() if + p.requires_grad and not any(nd in n for nd in no_decay)], + "weight_decay": optimizer_config.weight_decay, + }, + { + "params": [p for n, p in model.named_parameters() if + p.requires_grad and any(nd in n for nd in no_decay)], + "weight_decay": 0.0, + }, + ] + + optimizer_type = config.optimizer.name + if optimizer_type == "adamw": + optimizer = AdamW( + optimizer_grouped_parameters, + lr=optimizer_config.learning_rate, + betas=(optimizer_config.beta1, optimizer_config.beta2), + weight_decay=optimizer_config.weight_decay, + eps=optimizer_config.epsilon, + ) + else: + raise ValueError(f"Optimizer {optimizer_type} not supported") + + # Create mask scheduler + if config.get("mask_schedule", None) is not None: + schedule = config.mask_schedule.schedule + args = config.mask_schedule.get("params", {}) + mask_schedule = get_mask_schedule(schedule, **args) + else: + mask_schedule = get_mask_schedule(config.training.get("mask_schedule", "cosine")) + + lr_scheduler = get_scheduler( + config.lr_scheduler.scheduler, + optimizer=optimizer, + num_training_steps=config.training.max_train_steps, + num_warmup_steps=config.lr_scheduler.params.warmup_steps, + min_lr_scale=config.lr_scheduler.params.min_lr_scale + ) + + ################################## + # DATALOADER # + ################################# + logger.info("Creating dataloaders and lr_scheduler") + + # DataLoaders creation: + # We use webdataset for data loading. The dataloaders are created with sampling with replacement. + # We don't do dataset resuming here, instead we resample the shards and buffer each time. The sampling is stochastic. + # This means that the dataloading is not deterministic, but it's fast and efficient. + preproc_config = config.dataset.preprocessing + dataset_config = config.dataset.params + + video_captioning_dataset = VideoCaptionDataset( + transform=image_transform, + tokenizer=uni_prompting.text_tokenizer, + max_seq_length=preproc_config.max_seq_length, + resolution=preproc_config.resolution, + dataset_name = "llavavid", + sample_method="uniform", + num_frames=8, + ) + + video_captioning_dataloader = DataLoader( + video_captioning_dataset, + batch_size=config.training.batch_size_v2t, + num_workers=dataset_config.num_workers, + collate_fn=collate_fn_video_caption, + ) + + num_update_steps_per_epoch = math.ceil( + config.training.batch_size_v2t / config.training.gradient_accumulation_steps + ) + num_train_epochs = math.ceil(config.training.max_train_steps / num_update_steps_per_epoch) + + + ################################## + # MODEL RESUME # + ################################# + global_step = 0 + first_epoch = 0 + start_step = 0 + + if config.experiment.resume_from_checkpoint: + dirs = os.listdir(config.experiment.output_dir) + logger.info(f"dirs: {dirs}") + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + logger.info(f"path: {path}") + if path is not None: + path = os.path.join(config.experiment.output_dir, path) + logger.info(f"Resuming from checkpoint: {path}") + global_step = start_step = int(os.path.basename(path).split("-")[1]) + first_epoch = global_step // num_update_steps_per_epoch + if os.path.exists(f'{path}/unwrapped_model/pytorch_model.bin'): + state_dict = torch.load(f'{path}/unwrapped_model/pytorch_model.bin', map_location="cpu") + model.load_state_dict(state_dict, strict=True) + del state_dict + elif os.path.exists(f'{path}/unwrapped_model/pytorch_model.bin.index.json'): + from safetensors.torch import load_file + from transformers.modeling_utils import load_sharded_checkpoint + load_sharded_checkpoint(model, f'{path}/unwrapped_model/') + # if safetensors sharded checkpoint exists + elif os.path.exists(f'{path}/unwrapped_model/model.safetensors.index.json'): + from transformers.modeling_utils import load_sharded_checkpoint + load_sharded_checkpoint( + model, + f'{path}/unwrapped_model/', + # weight_map=None, + # load_state_dict_fn="safetensors" + ) + else: + raise FileNotFoundError(f"Checkpoint {path}/unwrapped_model/pytorch_model.bin not found") + else: + logger.info("Not resuming from checkpoint") + + ################################## + # Prepare accelerator # + ################################# + logger.info("Preparing model, optimizer and dataloaders") + model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler) + + vq_model.to(device=accelerator.device) + + mask_dtype = model.get_input_embeddings().weight.dtype + + ################################## + # Training # + ################################# + logger.info("***** Running training *****") + logger.info(f" Num training steps = {config.training.max_train_steps}") + logger.info(f" Instantaneous batch size per device = {total_batch_size_per_gpu}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {config.training.gradient_accumulation_steps}") + + @torch.no_grad() + def prepare_inputs_and_labels( + pixel_values_or_image_ids: Union[torch.FloatTensor, torch.LongTensor], + texts: Union[str, list[str]], + min_masking_rate: float = 0.0, + is_train: bool = True, + seed: int = None + ): + + image_tokens = vq_model.get_code(pixel_values_or_image_ids) + image_tokens = image_tokens + len(uni_prompting.text_tokenizer) + # create MLM mask and labels + input_ids, labels, loss_weight, mask_prob = mask_or_random_replace_tokens( + image_tokens, + mask_id, + config, + mask_schedule=mask_schedule, + is_train=is_train, + seed=seed + ) + input_ids, masks, labels = uni_prompting((texts, input_ids, labels), 't2i') + return input_ids, labels, mask_prob, image_tokens, masks + + @torch.no_grad() + def prepare_inputs_and_labels_for_text( + texts: Union[str, list[str]], max_seq_len, eps=1e-3 + ): + # create MLM mask and labels + + input_ids_lm, prompt_mask, labels_lm = uni_prompting((texts, max_seq_len), 'lm') + b, l = input_ids_lm.shape + t = torch.rand(b, device=input_ids_lm.device) + p_mask = (1 - eps) * t + eps + p_mask = p_mask[:, None].repeat(1, l) + + masked_indices = torch.rand((b, l), device=input_ids_lm.device) < p_mask + # 126336 is used for [MASK] token + noisy_batch = torch.where(masked_indices, mask_id, input_ids_lm) + masked_indices = noisy_batch == mask_id + + return noisy_batch, labels_lm, p_mask + + @torch.no_grad() + def prepare_inputs_and_labels_for_chat_text( + texts: Union[str, list[str]], max_seq_len, eps=1e-3 + ): + # create MLM mask and labels + + input_ids_lm, prompt_mask, labels_lm = uni_prompting((texts, max_seq_len), 'lm_chat') + b, l = input_ids_lm.shape + t = torch.rand(b, device=input_ids_lm.device) + p_mask = (1 - eps) * t + eps + p_mask = p_mask[:, None].repeat(1, l) + + masked_indices = torch.rand((b, l), device=input_ids_lm.device) < p_mask + # 126336 is used for [MASK] token + noisy_batch = torch.where(masked_indices, mask_id, input_ids_lm) + masked_indices = noisy_batch == mask_id + noisy_batch[prompt_mask.bool()] = input_ids_lm[prompt_mask.bool()] + masked_indices = noisy_batch == mask_id + answer_lengths_lm = torch.sum((1 - prompt_mask), dim=-1, keepdim=True) + answer_lengths_lm = answer_lengths_lm.repeat(1, noisy_batch.shape[1]) + + return noisy_batch, labels_lm, p_mask, answer_lengths_lm + + @torch.no_grad() + def prepare_inputs_and_labels_for_mmu( + input_ids_mmu, prompt_masks, labels_mmu, eps=1e-3 + ): + b, l = input_ids_mmu.shape + t = torch.rand(b, device=input_ids_mmu.device) + p_mask = (1 - eps) * t + eps + p_mask = p_mask[:, None].repeat(1, l) + + masked_indices = torch.rand((b, l), device=input_ids_mmu.device) < p_mask + # 126336 is used for [MASK] token + noisy_batch = torch.where(masked_indices, mask_id, input_ids_mmu) + masked_indices = noisy_batch == mask_id + noisy_batch[prompt_masks.bool()] = input_ids_mmu[prompt_masks.bool()] + masked_indices = noisy_batch == mask_id + + prompt_masks = prompt_masks.to(torch.int64) + answer_lengths = torch.sum((1 - prompt_masks), dim=-1, keepdim=True) + answer_lengths = answer_lengths.repeat(1, noisy_batch.shape[1]) + + return noisy_batch, labels_mmu, p_mask, answer_lengths + + batch_time_m = AverageMeter() + data_time_m = AverageMeter() + end = time.time() + video_caption_table = wandb.Table(columns=["step", "video_filename", "caption"]) + + for epoch in tqdm(range(first_epoch, num_train_epochs), desc="Epochs", disable=not accelerator.is_main_process, position=0): + model.train() + + batch_bar = tqdm( + video_captioning_dataloader, + desc=f"Epoch {epoch + 1}/{num_train_epochs}", + position=1, + leave=False, + disable=not accelerator.is_main_process, + ) + + for batch in batch_bar: + # for loss calculation + batch_size_vid = batch["video"].shape[0] + + # *-------*-------*-------*-------*-------*-------*-------*-------*-------*-------*-------* + # Build formatted sequences for captioning/multimodal understanding + # *-------*-------*-------*-------*-------*-------*-------*-------*-------*-------*-------* + if "llava" in config.dataset.und_type: + raise NotImplementedError + + else: + video_tensor, texts_vid = batch["video"], batch["captions"] + + video_tensor = video_tensor.to(accelerator.device, non_blocking=True) + video_token_list = [] + for video in video_tensor: # each video is (T, C, H, W) + video_token = vq_model.get_code(video) # (T, D) + # each video is tokenized into (T, D) + video_token = video_token + len(uni_prompting.text_tokenizer) # add offset for video tokens + video_token = video_token.view(-1) # flatten to (T*D) + video_token_list.append(video_token) + + video_tokens = torch.stack(video_token_list, dim=0) # (B, T*D) + input_ids_vid, prompt_masks, labels_vid = uni_prompting((video_tokens, texts_vid), 'v2t') + + ( + input_ids_vid, + labels_vid, + p_mask_vid, + answer_lengths + ) = prepare_inputs_and_labels_for_mmu(input_ids_vid, prompt_masks, labels_vid) + + input_ids_vid = input_ids_vid.to(accelerator.device, non_blocking=True) + + input_ids = input_ids_vid + labels = labels_vid + + + if global_step == 0 and epoch == 0: + logger.info("Input ids: {}".format(input_ids)) + logger.info("Input ids Shape: {}".format(input_ids.shape)) + logger.info("Labels: {}".format(labels)) + + logits, loss_v2t = model.forward_v2t( + input_ids=input_ids, + labels=labels, + batch_size_v2t=batch_size_vid, + max_seq_length=config.dataset.preprocessing.max_seq_length, + p_mask_v2t=p_mask_vid, + answer_lengths=answer_lengths, + ) + # Gather the losses across all processes for logging (if we use distributed training). + avg_loss_v2t = accelerator.gather(loss_v2t.repeat(config.training.batch_size_v2t)).mean() + loss = loss_v2t + + accelerator.backward(loss) + + if config.training.max_grad_norm is not None and accelerator.sync_gradients: + accelerator.clip_grad_norm_(model.parameters(), config.training.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + + # log gradient norm before zeroing it + if ( + accelerator.sync_gradients + and (global_step + 1) % config.experiment.log_grad_norm_every == 0 + and accelerator.is_main_process + ): + log_grad_norm(model, accelerator, global_step + 1) + + optimizer.zero_grad(set_to_none=True) + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + + batch_time_m.update(time.time() - end) + end = time.time() + + # Log metrics + if (global_step + 1) % config.experiment.log_every == 0: + samples_per_second_per_gpu = ( + config.training.gradient_accumulation_steps * total_batch_size_per_gpu / batch_time_m.val + ) + logs = { + "step_loss_v2t": avg_loss_v2t.item(), + "lr": lr_scheduler.get_last_lr()[0], + # "avg_masking_rate": avg_masking_rate.item(), + "samples/sec/gpu": samples_per_second_per_gpu, + "data_time": data_time_m.val, + "batch_time": batch_time_m.val, + } + accelerator.log(logs, step=global_step + 1) + + logger.info( + f"Step: {global_step + 1} " + f"Loss_v2t: {avg_loss_v2t.item():0.4f} " + f"Data (t): {data_time_m.val:0.4f}, {samples_per_second_per_gpu:0.2f}/s/gpu " + f"Batch (t): {batch_time_m.val:0.4f} " + f"LR: {lr_scheduler.get_last_lr()[0]:0.6f}" + ) + + # resetting batch / data time meters per log window + batch_time_m.reset() + data_time_m.reset() + # Save model checkpoint + if (global_step + 1) % config.experiment.save_every == 0: + save_checkpoint(model, config, accelerator, global_step + 1, uni_prompting) + + + + global_step += 1 + # Stop training if max steps is reached + if global_step >= config.training.max_train_steps: + break + # End for + + if accelerator.is_main_process: + understanding_video( + model, + vq_model, + uni_prompting, + accelerator, + config, + global_step + 1, + wandb_table=video_caption_table + ) + + accelerator.wait_for_everyone() + + # Evaluate and save checkpoint at the end of training + save_checkpoint(model, config, accelerator, global_step, uni_prompting) + + # Save the final trained checkpoint + if accelerator.is_main_process: + model = accelerator.unwrap_model(model) + model.save_pretrained(config.experiment.output_dir, safe_serialization=True) + + accelerator.end_training() + + +def save_checkpoint(model, config, accelerator, global_step, uni_prompting): + output_dir = config.experiment.output_dir + checkpoints_total_limit = config.experiment.get("checkpoints_total_limit", None) + + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if accelerator.is_main_process and checkpoints_total_limit is not None: + checkpoints = os.listdir(output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= checkpoints_total_limit: + num_to_remove = len(checkpoints) - checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = Path(output_dir) / f"checkpoint-{global_step}" + + # retrieve the model on all processes for deepspeed stage 3 to work then save on one process (we are not using stage 3 yet) + # XXX: could also make this conditional on deepspeed + state_dict = accelerator.get_state_dict(model) + if accelerator.is_main_process: + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.save_pretrained( + save_path / "unwrapped_model", + save_function=accelerator.save, + state_dict=state_dict, + safe_serialization=True + ) + json.dump({"global_step": global_step}, (save_path / "metadata.json").open("w+")) + logger.info(f"Saved state to {save_path}") + + # save tokenizer + uni_prompting.text_tokenizer.save_pretrained(save_path/ "unwrapped_model") + + +def log_grad_norm(model, accelerator, global_step): + for name, param in model.named_parameters(): + if param.grad is not None: + grads = param.grad.detach().data + grad_norm = (grads.norm(p=2) / grads.numel()).item() + accelerator.log({"grad_norm/" + name: grad_norm}, step=global_step) + + +@torch.no_grad() +def understanding_video( + model, + vq_model, + uni_prompting, + accelerator, + config, + global_step, + wandb_table=None, +): + logger.info("Understanding videos..") + model.eval() + + video_image_root = "/home/work/AIDAS/video/demo" + + file_list = os.listdir(video_image_root) + file_list = [f for f in file_list if f.lower().endswith(('.mp4'))] + responses = ['' for i in range(len(file_list))] + videos = [] + + device = accelerator.device + + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + else: + weight_dtype = torch.float32 + + for i, file_name in enumerate(file_list): + + batch_size = 1 + + video_path = os.path.join(video_image_root, file_name) + video_frames = load_video_mp4( + video_path=video_path, + num_frames=8, + sample_method="uniform", + resolution=config.dataset.preprocessing.resolution, + transform=image_transform, + ) + + frames = torch.stack(video_frames, dim=0).to(device) # (T, C, H, W) + video_tokens = vq_model.get_code(frames) + len(uni_prompting.text_tokenizer) + video_tokens = video_tokens.view(batch_size, -1) # (T * D) + + input_ids = uni_prompting.text_tokenizer(['<|start_header_id|>user<|end_header_id|>\n' + "Please describe this video in detail." +'<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n'])['input_ids'] + input_ids = torch.tensor(input_ids).to(device) + + input_ids = torch.cat([ + (torch.ones(input_ids.shape[0], 1) * uni_prompting.sptids_dict['<|v2t|>']).to(device), + (torch.ones(input_ids.shape[0], 1) * uni_prompting.sptids_dict['<|soi|>']).to(device), + video_tokens, + (torch.ones(input_ids.shape[0], 1) * uni_prompting.sptids_dict['<|eoi|>']).to(device), + (torch.ones(input_ids.shape[0], 1) * uni_prompting.sptids_dict['<|sot|>']).to(device), + input_ids + ], dim=1).long() + with torch.autocast("cuda", dtype=weight_dtype, enabled=accelerator.mixed_precision != "no"): + output_ids = accelerator.unwrap_model(model).mmu_generate(input_ids, max_new_tokens=config.dataset.preprocessing.max_seq_length, steps=config.dataset.preprocessing.max_seq_length // 2, block_length=config.dataset.preprocessing.max_seq_length // 4) + # output_ids = torch.stack(output_ids).squeeze()[None] + + text = uni_prompting.text_tokenizer.batch_decode(output_ids[:, input_ids.shape[1]:], skip_special_tokens=True) + responses[i] += text[0] + + model.train() + + # Log result + + for file_name, caption in zip(file_list, responses): + wandb_table.add_data(global_step, file_name, caption) + logger.info(f"step: {global_step}, video: {file_name}, caption: {caption}") + + accelerator.log({"video_captioning":wandb_table}, step=global_step) + # wandb.log({"video_captioning": wandb_table}, step=global_step) + # print(f"[step {global_step}] wandb_table rows = {len(wandb_table.data)}") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/MMaDA/training/train_mmada_v2t_inst.py b/MMaDA/training/train_mmada_v2t_inst.py new file mode 100644 index 0000000000000000000000000000000000000000..3ce15a057f385957aaa6ac5979d5e89679a4ed09 --- /dev/null +++ b/MMaDA/training/train_mmada_v2t_inst.py @@ -0,0 +1,1809 @@ +# Copyright 2025 AIDAS Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +## NOTE: hide ffmpeg error logs +# but why does so much error logs appear? idk +import warnings, os +import cv2 + +os.environ["FFMPEG_LOG_LEVEL"] = "error" +warnings.filterwarnings("ignore") + +import os +import sys +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +os.environ["TOKENIZERS_PARALLELISM"] = "true" +import json +import pandas +import logging +import math +import shutil +import time +import cv2 +import glob +import random +from tqdm import tqdm +from pathlib import Path +from typing import Optional, Union +import csv +import numpy as np +from PIL import Image +from omegaconf import OmegaConf +import wandb +import torch +from torch.optim import AdamW +from lightning.pytorch.utilities import CombinedLoader + +from transformers import AutoTokenizer, AutoConfig +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import DistributedType, set_seed +# +++++ I2I-specific Imports +++++ +from datasets import load_dataset +from torch.utils.data import Dataset, DataLoader +from tqdm.auto import tqdm +# ++++++++++++++++++++++++++++++ + +# +++++ Omni-modal-specific Imports +++++ +from models.modeling_emova_speech_tokenizer import EMOVASpeechTokenizer +from datasets import load_dataset +from torch.utils.data import Dataset, DataLoader, DistributedSampler +from tqdm.auto import tqdm +from training.data import SpeechTextDataset, MixedSpeechTextDataset, load_video_mp4, VideoCaptionDataset, S2T_INSTRUCTION, T2S_INSTRUCTION, V2T_INSTRUCTION +# import librosa + +from training.data import Text2ImageDataset +from training.utils import get_config, flatten_omega_conf, image_transform +from training.imagenet_dataset import ImageNetDataset +from parquet import RefinedWebDataset, ChatDataset + +from models import MAGVITv2, get_mask_schedule, OMadaModelLM, OMadaConfig +from training.prompting_utils import UniversalPrompting +from models.lr_schedulers import get_scheduler +from models.logging import set_verbosity_info, set_verbosity_error + +from torch.utils.data import DataLoader, Dataset +from torch.utils.data.distributed import DistributedSampler + +# ++++++++ EVALUATION IMPORTS ++++++++ +import re +import editdistance +import soundfile as sf +from functools import partial +from transformers import pipeline +# ++++++++++++++++++++++++++++++++++++ + + + +SYSTEM_PROMPT_LEN = 28 + +from training.utils import get_config, flatten_omega_conf, mask_or_random_replace_tokens, AverageMeter + +try: + import apex + + is_apex_available = True +except ImportError: + is_apex_available = False + +logger = get_logger(__name__, log_level="INFO") + +def pad_tensor(tensor, length, value): + pad_size = length - tensor.shape[1] + if pad_size <= 0: + return tensor + # Pad on the right side of the sequence (last dimension) + return torch.nn.functional.pad(tensor, (0, pad_size), "constant", value) + +def pad_answer_lengths(ans: torch.Tensor, length: int) -> torch.Tensor: + b, l = ans.shape + if l >= length: + return ans + pad_block = ans[:, :1].expand(b, length - l) + return torch.cat([ans, pad_block], dim=1) + +def resize_vocab(model, config): + logger.info(f"Resizing token embeddings to {config.model.omada.new_vocab_size}") + model.resize_token_embeddings(config.model.omada.new_vocab_size) + +def get_vq_model_class(model_type): + if model_type == "magvitv2": + return MAGVITv2 + elif model_type == "emova": + return EMOVASpeechTokenizer.from_pretrained( + "Emova-ollm/emova_speech_tokenizer_hf" + ) + else: + raise ValueError(f"model_type {model_type} not supported.") + +def collate_fn_audio(batch): + # In this setup, the tokenizer handles batching of audio paths + return { + 'audio_path': [item['audio_path'] for item in batch], + 'text': [item['text'] for item in batch], + } + +def collate_fn_video_caption(batch): + + batch = [item for item in batch if item is not None] + if len(batch) == 0: + return None + + frame_list = [] + input_ids_list = [] + for item in batch: + frame_tensor = torch.stack(item['video'], dim=0) # (T, C, H, W) + frame_list.append(frame_tensor) + input_ids_list.append(item['caption']) + + frames = torch.stack(frame_list, dim=0) # (B, T, C, H, W) + + return { + "video": frames, # torch tensor (B, T, C, H, W) + "captions": input_ids_list # input_ids (B, seq_len) + } + +def s2t_eval_collate_fn(batch, vq_model_audio, tokenizer, uni_prompting, config): + + audio_tokens_batch = [] + offset = len(uni_prompting.text_tokenizer) + int(config.model.omada.codebook_size) + for item in batch: + path = item['audio_path'] + tokens = vq_model_audio.encode(path) + tokens_with_offset = tokens + offset + audio_tokens_batch.append(tokens_with_offset) + + sptids_dict = uni_prompting.sptids_dict + device = audio_tokens_batch[0].device + batched_input_ids = [] + + for audio_tokens in audio_tokens_batch: + task_tensor = sptids_dict['<|s2t|>'].to(device).unsqueeze(0) + soa_tensor = sptids_dict['<|soa|>'].to(device).unsqueeze(0) + eoa_tensor = sptids_dict['<|eoa|>'].to(device).unsqueeze(0) + audio_block = torch.cat([task_tensor, soa_tensor, audio_tokens, eoa_tensor], dim=1) + + prompt_text = random.choice(S2T_INSTRUCTION) + full_prompt_text = f'<|start_header_id|>user<|end_header_id|>\n{prompt_text}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n' + prompt_tensor = tokenizer(full_prompt_text, return_tensors="pt").input_ids.to(device) + + final_sequence = torch.cat([audio_block, prompt_tensor], dim=1) + batched_input_ids.append(final_sequence.squeeze(0)) + + max_len = max(seq.size(0) for seq in batched_input_ids) + pad_token_id = 126093 + + final_batch_input_ids = torch.full( + (len(batched_input_ids), max_len), + pad_token_id, + dtype=torch.long, + device=device + ) + + for i, seq in enumerate(batched_input_ids): + final_batch_input_ids[i, -len(seq):] = seq + + return { + "input_ids": final_batch_input_ids, + "gt_texts": [item['gt_text'] for item in batch], + "sample_ids": [item['sample_id'] for item in batch] + } + +################################################################################################ +# +++++++++++++++++++++++++++++++++++++ EVALUATION HELPERS +++++++++++++++++++++++++++++++++++++ +################################################################################################ + +def normalize_text(text): + """A simple normalizer for WER calculation.""" + text = text.lower() + text = re.sub(r"[^\w\s']", "", text) + return text + +def calculate_wer(predictions, references): + """Calculates the Word Error Rate (WER) between predicted and ground truth texts.""" + predictions = [normalize_text(p) for p in predictions] + references = [normalize_text(r) for r in references] + + total_errors = 0 + total_words = 0 + for pred, ref in zip(predictions, references): + pred_words = pred.split() + ref_words = ref.split() + total_errors += editdistance.eval(pred_words, ref_words) + total_words += len(ref_words) + + wer = total_errors / total_words if total_words > 0 else 0.0 + return wer, total_errors, total_words + +class S2TEvalDataset(Dataset): + def __init__(self, hf_dataset, root_path): + self.hf_dataset = hf_dataset + self.root_path = root_path + + def __len__(self): + return len(self.hf_dataset) + + def __getitem__(self, idx): + example = self.hf_dataset[idx] + sample_id = example['id'] + speaker_id, chapter_id, _ = sample_id.split('-') + audio_path = os.path.join(self.root_path, speaker_id, chapter_id, f"{sample_id}.flac") + + return { + "audio_path": audio_path, + "gt_text": example["text"], + "sample_id": sample_id + } + +# --- T2S Evaluation Dataset --- +class T2SEvalDataset(Dataset): + def __init__(self, hf_dataset): + self.hf_dataset = hf_dataset + def __len__(self): + return len(self.hf_dataset) + def __getitem__(self, idx): + example = self.hf_dataset[idx] + return {"gt_text": example['text'], "sample_id": example['id']} + + +################################################################################################ +# +++++++++++++++++++++++++++++++++++++ S2T EVALUATION LOGIC +++++++++++++++++++++++++++++++++++++ +################################################################################################ +@torch.no_grad() +def evaluate_s2t(model, vq_model_audio, uni_prompting, config, accelerator, global_step): + if not accelerator.is_main_process: + return + logger.info("***** Running S2T Evaluation (WER on Librispeech test-clean) *****") + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.eval() + + # 1. Load Dataset + try: + s2t_eval_dataset_raw = load_dataset("librispeech_asr", "clean", split="test", streaming=False).select(range(32)) + s2t_eval_dataset = S2TEvalDataset(s2t_eval_dataset_raw, root_path = "/home/work/AIDAS/data/audio/LibriSpeech/test-clean") + except Exception as e: + logger.error(f"Failed to load S2T evaluation dataset: {e}") + return + + collate_with_args = partial( + s2t_eval_collate_fn, + vq_model_audio=vq_model_audio, + tokenizer=uni_prompting.text_tokenizer, + uni_prompting=uni_prompting, + config=config + ) + + s2t_eval_dataloader = DataLoader(s2t_eval_dataset, batch_size=config.training.batch_size_s2t, shuffle=False, collate_fn=collate_with_args) + + local_results = [] + + for batch in tqdm(s2t_eval_dataloader, desc="S2T Evaluation"): + input_ids = batch["input_ids"] + gt_texts = batch["gt_texts"] + sample_ids = batch["sample_ids"] + + output_ids = unwrapped_model.mmu_generate(input_ids, max_new_tokens=256, steps=256, block_length=128, remasking='low_confidence') + + decoded_texts = uni_prompting.text_tokenizer.batch_decode(output_ids[:, input_ids.shape[1]:], skip_special_tokens=True) + + eos_token = uni_prompting.text_tokenizer.eos_token + eos_marker = eos_token if eos_token is not None else "" + for i in range(len(decoded_texts)): + full_text = decoded_texts[i] + eos_idx = full_text.find(eos_marker) + cleaned_text = full_text[:eos_idx] if eos_idx != -1 else full_text + cleaned_text = cleaned_text.replace(eos_marker, "").strip() + local_results.append({ + "sample_id": sample_ids[i], + "gt_text": gt_texts[i], + "decoded_text": cleaned_text, + }) + + if not local_results: + logger.warning("S2T evaluation produced no results.") + return + + gt_list = [res["gt_text"] for res in local_results] + pred_list = [res["decoded_text"] for res in local_results] + + wer, errors, words = calculate_wer(pred_list, gt_list) + logger.info(f"S2T Final WER (Librispeech test-clean): {wer:.4f} | Word Errors: {errors} | Total Words: {words}") + + accelerator.log({ + "eval/s2t_wer": wer, + "eval/s2t_word_errors": errors, + "eval/s2t_total_words": words + }, step=global_step) + + samples_table = wandb.Table(columns=["ID", "Ground Truth", "Prediction"]) + for idx, res in enumerate(local_results): + sample_id = res.get("sample_id", idx) + samples_table.add_data(sample_id, res["gt_text"], res["decoded_text"]) + + accelerator.log({"eval/s2t_samples": samples_table}, step=global_step) + +################################################################################################ +# +++++++++++++++++++++++++++++++++++++ T2S EVALUATION LOGIC +++++++++++++++++++++++++++++++++++++ +################################################################################################ +@torch.no_grad() +def evaluate_t2s(model, vq_model_audio, uni_prompting, config, accelerator, global_step): + if not accelerator.is_main_process: + return + logger.info("***** Running T2S Evaluation (WER via Whisper on Librispeech) *****") + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.eval() + + # 1. Load Dataset & Whisper Model + try: + t2s_eval_dataset_raw = load_dataset("librispeech_asr", "clean", split="test").select(range(32)) + whisper_pipe = pipeline("automatic-speech-recognition", model="openai/whisper-large-v3", device=accelerator.device) + os.makedirs(f"{config.experiment.output_dir}/eval_audio", exist_ok=True) + except Exception as e: + logger.error(f"Failed to load T2S dataset or Whisper model: {e}") + return + + output_dir_per_step = os.path.join("/home/work/AIDAS", config.experiment.output_dir, "eval_audio", f"step_{global_step}") + os.makedirs(output_dir_per_step, exist_ok=True) + + t2s_eval_dataset = T2SEvalDataset(t2s_eval_dataset_raw) + t2s_dataloader = DataLoader(t2s_eval_dataset, batch_size=config.training.batch_size_t2s) + + local_results = [] + mask_token_id = unwrapped_model.config.mask_token_id + mask_schedule = get_mask_schedule(config.training.get("mask_schedule", "cosine")) + + # 2. Evaluation Loop + for batch in tqdm(t2s_dataloader, desc="T2S Evaluation"): + gt_texts = batch["gt_text"] + sample_ids = batch["sample_id"] + + # Chat-style instruction formatting for T2S: user prompt + text + prompts = [ + f"<|start_header_id|>user<|end_header_id|>\n{random.choice(T2S_INSTRUCTION)}\n{text}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n" + for text in gt_texts + ] + batch_size = len(prompts) + + # We need a reasonable length for generated audio tokens + speech_token_length = 384 - 1 # -1 for soa token + audio_tokens = torch.ones((batch_size, speech_token_length), dtype=torch.long, device=accelerator.device) * mask_token_id + input_ids, attention_mask = uni_prompting((prompts, audio_tokens), 't2s_gen') + + if config.training.guidance_scale > 0: + uncond_input_ids, uncond_attention_mask = uni_prompting(([''] * batch_size, audio_tokens), 't2s_gen') + else: + uncond_input_ids, uncond_attention_mask = None, None + + output_ids = unwrapped_model.t2s_generate( + input_ids=input_ids, + uncond_input_ids=uncond_input_ids, + attention_mask=attention_mask, + uncond_attention_mask=uncond_attention_mask, + guidance_scale=5.0, + temperature=1.0, + timesteps=50, + noise_schedule=mask_schedule, + noise_type="mask", + seq_len=383, + uni_prompting=uni_prompting, + config=config, + ) + + # Decode and run Whisper + for i in range(batch_size): + gt = gt_texts[i].rsplit("\n", 1)[-1].strip() + + gen_speech_tokens = output_ids[i] + + # Remove padding/eos if necessary, clamp to valid range + # gen_speech_tokens = torch.clamp(gen_speech_tokens, min=0, max= 4096 - 1) + id_list = gen_speech_tokens.cpu().tolist() + + if not id_list: + logger.warning(f"Generated token list is empty for sample {sample_ids[i]}. Skipping.") + continue + + speech_unit_str = " ".join(map(str, id_list)) + speech_unit_for_decode = "".join([f"<|speech_{unit}|>" for unit in speech_unit_str.split(" ")]) + + filename = f"process_{accelerator.process_index}_{sample_ids[i]}.wav" + output_wav_path = os.path.join(output_dir_per_step, filename) + condition = 'gender-female_emotion-neutral_speed-normal_pitch-normal' + + audio_array = vq_model_audio.decode(speech_unit_for_decode, condition=condition, output_wav_file=output_wav_path) + + whisper_result = whisper_pipe(output_wav_path, generate_kwargs={"language": "english"}) + whisper_text = whisper_result.get("text", "") + + local_results.append({ + "sample_id": sample_ids[i], "gt_text": gt, "whisper_text": whisper_text, "audio_path": output_wav_path + }) + + if not local_results: + logger.warning("Skipping T2S evaluation logging because no samples were generated.") + return + + gt_list = [res["gt_text"] for res in local_results] + pred_list = [res["whisper_text"] for res in local_results] + + wer, errors, words = calculate_wer(pred_list, gt_list) + logger.info(f"T2S Final WER (via Whisper): {wer:.4f} | Word Errors: {errors} | Total Words: {words}") + + accelerator.log({ + "eval/t2s_wer": wer, + "eval/t2s_word_errors": errors, + "eval/t2s_total_words": words + }, step=global_step) + + results_table = wandb.Table(columns=["ID", "Ground Truth", "Whisper Transcription", "Generated Audio"]) + for res in local_results[:8]: + audio = wandb.Audio(res["audio_path"], caption=res["whisper_text"]) + results_table.add_data(res["sample_id"], res["gt_text"], res["whisper_text"], audio) + + accelerator.log({"eval/t2s_samples": results_table}, step=global_step) + +@torch.no_grad() +def evaluate_t2s_mmu_like(model, vq_model_audio, uni_prompting, config, accelerator, global_step): + """Text-to-speech evaluation using the MMU-style block refinement decoder.""" + + if not accelerator.is_main_process: + return + + logger.info("***** Running T2S Evaluation (MMU-style decoder) *****") + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.eval() + + try: + t2s_eval_dataset_raw = load_dataset("librispeech_asr", "clean", split="test").select(range(32)) + whisper_pipe = pipeline("automatic-speech-recognition", model="openai/whisper-large-v3", device=accelerator.device) + os.makedirs(f"{config.experiment.output_dir}/eval_audio", exist_ok=True) + except Exception as exc: + logger.error(f"Failed to load T2S dataset or Whisper model for MMU-style eval: {exc}") + return + + output_dir_per_step = os.path.join("/home/work/AIDAS", config.experiment.output_dir, "eval_audio", f"step_{global_step}_mmu") + os.makedirs(output_dir_per_step, exist_ok=True) + + t2s_eval_dataset = T2SEvalDataset(t2s_eval_dataset_raw) + t2s_dataloader = DataLoader(t2s_eval_dataset, batch_size=config.training.batch_size_t2s) + + local_results = [] + mask_token_id = unwrapped_model.config.mask_token_id + + codebook_size = config.model.omada.codebook_size + speech_vocab_size = 4096 + + for batch in tqdm(t2s_dataloader, desc="T2S MMU Eval"): + gt_texts = batch["gt_text"] + sample_ids = batch["sample_id"] + + prompts = [ + f"<|start_header_id|>user<|end_header_id|>\n{random.choice(T2S_INSTRUCTION)}\n{text}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n" + for text in gt_texts + ] + + batch_size = len(prompts) + speech_token_length = 384 - 1 + audio_tokens = torch.ones((batch_size, speech_token_length), dtype=torch.long, device=accelerator.device) * mask_token_id + input_ids, attention_mask = uni_prompting((prompts, audio_tokens), 't2s_gen') + + output_ids = unwrapped_model.t2s_generate_mmu_like( + input_ids=input_ids, + max_new_tokens=speech_token_length, + steps=384 - 1, + block_length=384 - 1, + temperature=1.0, + cfg_scale=3.0, + mask_token_id=mask_token_id, + attention_mask=attention_mask, + uni_prompting=uni_prompting, + codebook_size=codebook_size, + audio_codebook_size=speech_vocab_size, + ) + + for i in range(batch_size): + gt = gt_texts[i].rsplit("\n", 1)[-1].strip() + + gen_speech_tokens = output_ids[i] + if isinstance(gen_speech_tokens, torch.Tensor): + gen_speech_tokens = gen_speech_tokens.detach().cpu() + + token_list = gen_speech_tokens.tolist() + if not token_list: + logger.warning(f"Generated token list is empty for sample {sample_ids[i]} (MMU eval). Skipping.") + continue + + speech_unit_str = " ".join(map(str, token_list)) + speech_unit_for_decode = "".join([f"<|speech_{unit}|>" for unit in speech_unit_str.split(" ")]) + + filename = f"process_{accelerator.process_index}_{sample_ids[i]}_mmu.wav" + output_wav_path = os.path.join(output_dir_per_step, filename) + condition = 'gender-female_emotion-neutral_speed-normal_pitch-normal' + + try: + vq_model_audio.decode(speech_unit_for_decode, condition=condition, output_wav_file=output_wav_path) + except Exception as exc: + logger.error(f"Decoding failed for sample {sample_ids[i]} (MMU eval): {exc}") + continue + + whisper_result = whisper_pipe(output_wav_path, generate_kwargs={"language": "english"}) + whisper_text = whisper_result.get("text", "") + + local_results.append({ + "sample_id": sample_ids[i], + "gt_text": gt, + "whisper_text": whisper_text, + "audio_path": output_wav_path, + }) + + if not local_results: + logger.warning("Skipping T2S MMU-style evaluation because no samples were generated.") + return + + gt_list = [res["gt_text"] for res in local_results] + pred_list = [res["whisper_text"] for res in local_results] + + wer, errors, words = calculate_wer(pred_list, gt_list) + logger.info(f"T2S (MMU-style) Final WER: {wer:.4f} | Word Errors: {errors} | Total Words: {words}") + + accelerator.log({ + "eval/t2s_mmu_like_wer": wer, + "eval/t2s_mmu_like_word_errors": errors, + "eval/t2s_mmu_like_total_words": words, + }, step=global_step) + + results_table = wandb.Table(columns=["ID", "Ground Truth", "Whisper Transcription", "Generated Audio"]) + for res in local_results[:8]: + audio = wandb.Audio(res["audio_path"], caption=res["whisper_text"]) + results_table.add_data(res["sample_id"], res["gt_text"], res["whisper_text"], audio) + + accelerator.log({"eval/t2s_mmu_like_samples": results_table}, step=global_step) + +@torch.no_grad() +def evaluate_t2s_fixed(model, vq_model_audio, uni_prompting, config, accelerator, global_step): + """ + Text-to-Speech (fixed-length) evaluation: + - Input prompt contains SOA + [MASK]*L + EOA (EOA is injected, not predicted) + - The model only fills VQ codes for exactly L positions (no EOA/EOS prediction) + - Generated audio is transcribed by Whisper; we report WER + """ + if not accelerator.is_main_process: + return + logger.info("***** Running T2S (fixed-length) Evaluation *****") + unwrapped = accelerator.unwrap_model(model) + unwrapped.eval() + + # Load eval dataset and Whisper model + try: + ds_raw = load_dataset("librispeech_asr", "clean", split="test").select(range(128)) + whisper_pipe = pipeline( + "automatic-speech-recognition", + model="openai/whisper-large-v3", + device=accelerator.device + ) + os.makedirs(f"{config.experiment.output_dir}/eval_audio", exist_ok=True) + except Exception as e: + logger.error(f"Failed to load dataset or Whisper model: {e}") + return + + # Directory for saving generated audio files of this evaluation step + out_dir = os.path.join( + "/home/work/AIDAS", config.experiment.output_dir, "eval_audio", f"step_{global_step}_fixed" + ) + os.makedirs(out_dir, exist_ok=True) + + eval_ds = T2SEvalDataset(ds_raw) + loader = DataLoader(eval_ds, batch_size=config.training.batch_size_t2s) + + local_results = [] + mask_token_id = unwrapped.config.mask_token_id + mask_schedule = get_mask_schedule(config.training.get("mask_schedule", "cosine")) + + for batch in tqdm(loader, desc="T2S Fixed Evaluation"): + gt_texts = batch["gt_text"] + sample_ids = batch["sample_id"] + + # Chat-style instruction formatting for fixed-length T2S + prompts = [ + f"<|start_header_id|>user<|end_header_id|>\n{random.choice(T2S_INSTRUCTION)}\n{text}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n" + for text in gt_texts + ] + batch_size = len(prompts) + + # We need a reasonable length for generated audio tokens + speech_token_length = 256 - 2 # exclude and + audio_tokens = torch.ones((batch_size, speech_token_length), dtype=torch.long, device=accelerator.device) * mask_token_id + input_ids, attention_mask = uni_prompting((prompts, audio_tokens), 't2s_fixed_gen') + + if config.training.guidance_scale > 0: + uncond_input_ids, uncond_attention_mask = uni_prompting(([''] * batch_size, audio_tokens), 't2s_fixed_gen') + else: + uncond_input_ids, uncond_attention_mask = None, None + + # Core generation call: + # - predict_eoa=False prevents EOA/EOS prediction; only VQ codes are generated + outputs = unwrapped.t2s_fixed_generate( + input_ids=input_ids, + uncond_input_ids=uncond_input_ids, + attention_mask=attention_mask, + uncond_attention_mask=uncond_attention_mask, + guidance_scale=1.5, + temperature=1.0, + timesteps=24, + noise_schedule=mask_schedule, + noise_type="mask", + seq_len=speech_token_length, + uni_prompting=uni_prompting, + config=config, + ) + + # Decode generated VQ codes → waveform via the speech tokenizer, then ASR with Whisper + for i in range(batch_size): + gt = gt_texts[i].rsplit("\n", 1)[-1].strip() + gen_rel = outputs[i] # relative VQ ids in [0..4095] + id_list = gen_rel.tolist() + + if not id_list: + logger.warning(f"[fixed] Empty tokens for {sample_ids[i]}; skipping.") + continue + + # Convert to the speech-unit string format expected by the decoder + unit_str = " ".join(map(str, id_list)) + speech_unit_for_decode = "".join([f"<|speech_{u}|>" for u in unit_str.split(" ")]) + + # Synthesize audio and run Whisper + fname = f"process_{accelerator.process_index}_{sample_ids[i]}_fixed.wav" + wav_path = os.path.join(out_dir, fname) + condition = 'gender-female_emotion-neutral_speed-normal_pitch-normal' + + _ = vq_model_audio.decode( + speech_unit_for_decode, + condition=condition, + output_wav_file=wav_path + ) + asr = whisper_pipe(wav_path, generate_kwargs={"language": "english"}) + whisper_text = asr.get("text", "") + + local_results.append({ + "sample_id": sample_ids[i], + "gt_text": gt, + "whisper_text": whisper_text, + "audio_path": wav_path + }) + + if not local_results: + logger.warning("Skipping T2S fixed evaluation logging because no samples were generated.") + return + + gt_list = [r["gt_text"] for r in local_results] + pred_list = [r["whisper_text"] for r in local_results] + wer, errors, words = calculate_wer(pred_list, gt_list) + logger.info(f"T2S Fixed WER: {wer:.4f} | Errors: {errors} | Words: {words}") + + accelerator.log({ + "eval/t2s_fixed_wer": wer, + "eval/t2s_fixed_errors": errors, + "eval/t2s_fixed_words": words + }, step=global_step) + + table = wandb.Table(columns=["ID", "GT", "ASR", "Audio"]) + for r in local_results[:8]: + table.add_data( + r["sample_id"], + r["gt_text"], + r["whisper_text"], + wandb.Audio(r["audio_path"], caption=r["whisper_text"]) + ) + accelerator.log({"eval/t2s_fixed_samples": table}, step=global_step) + +################################################################################################ +# +++++++++++++++++++++++++++++++++++++ V2T EVALUATION LOGIC +++++++++++++++++++++++++++++++++++++ +################################################################################################ +@torch.no_grad() +def evaluate_v2t(model, vq_model_image, uni_prompting, config, accelerator, global_step): + # This is a qualitative evaluation, so it only runs on the main process. + if not accelerator.is_main_process: + return + + logger.info("***** Running V2T Qualitative Evaluation *****") + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.eval() + + video_root = "/home/work/AIDAS/video/demo" + if not video_root or not os.path.exists(video_root): + logger.warning(f"V2T eval root '{video_root}' not found. Skipping V2T evaluation.") + return + + file_list = [f for f in os.listdir(video_root) if f.lower().endswith('.mp4')] + if not file_list: + logger.warning(f"No .mp4 files found in '{video_root}'. Skipping V2T evaluation.") + return + + question = "Please provide a detailed description of the video." + results_table = wandb.Table(columns=["Video ID", "Question", "Generated Caption"]) + + for file_name in tqdm(file_list[:], desc="V2T Evaluation", disable=not accelerator.is_main_process): + video_path = os.path.join(video_root, file_name) + + # 1. Load and process video + cap = cv2.VideoCapture(video_path) + total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + indices = np.linspace(0, total_frames - 1, 8, dtype=int) + frames = [] + for i in range(total_frames): + ret, frame = cap.read() + if i in indices: + if not ret: continue + frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + pil_img = Image.fromarray(frame) + frames.append(image_transform(pil_img, resolution=config.dataset.preprocessing.resolution)) + cap.release() + + if len(frames) < 8: continue + + video_tensor = torch.stack(frames).to(accelerator.device) + video_tokens = vq_model_image.get_code(video_tensor) + len(uni_prompting.text_tokenizer) + video_tokens = video_tokens.view(1, -1) # Flatten tokens + + sptids = uni_prompting.sptids_dict + device = unwrapped_model.device + + prompt_text = f'<|start_header_id|>user<|end_header_id|>\n{question}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n' + prompt_tensor = uni_prompting.text_tokenizer(prompt_text, return_tensors="pt").input_ids.to(device) + + input_ids = torch.cat([ + sptids['<|v2t|>'].to(device).unsqueeze(0), + sptids['<|soi|>'].to(device).unsqueeze(0), + video_tokens, + sptids['<|eoi|>'].to(device).unsqueeze(0), + sptids['<|sot|>'].to(device).unsqueeze(0), + prompt_tensor + ], dim=1).long() + + output_ids = unwrapped_model.mmu_generate(input_ids, max_new_tokens=256, steps=256, block_length=128) + text = uni_prompting.text_tokenizer.batch_decode(output_ids[:, input_ids.shape[1]:], skip_special_tokens=True)[0] + print(text) + # 3. Log result + results_table.add_data(file_name, question, text) + + # except Exception as e: + # logger.error(f"Error processing video {file_name}: {e}") + + accelerator.log({"eval/v2t_qualitative_samples": results_table}, step=global_step) + + +################################################################################################ +# +++++++++++++++++++++++++++++++++++++ MAIN EVALUATION ORCHESTRATOR +++++++++++++++++++++++++++++ +################################################################################################ + +def run_evaluation(model, vq_model_image, vq_model_audio, uni_prompting, config, accelerator, global_step): + """ + Orchestrates the S2T, T2S, and V2T evaluations. + """ + if accelerator.is_main_process: + logger.info(f"--- Starting evaluation at step {global_step} ---") + model.eval() + + if accelerator.is_main_process: + evaluate_s2t(model, vq_model_audio, uni_prompting, config, accelerator, global_step) + # evaluate_t2s(model, vq_model_audio, uni_prompting, config, accelerator, global_step) + evaluate_t2s_mmu_like(model, vq_model_audio, uni_prompting, config, accelerator, global_step) + # evaluate_t2s_fixed(model, vq_model_audio, uni_prompting, config, accelerator, global_step) + evaluate_v2t(model, vq_model_image, uni_prompting, config, accelerator, global_step) + + accelerator.wait_for_everyone() + if accelerator.is_main_process: + logger.info(f"--- Finished evaluation at step {global_step}. Returning to training. ---") + model.train() + + +def main(): + ######################### + # SETUP Accelerator # + ######################### + config = get_config() + + # Enable TF32 on Ampere GPUs + if config.training.enable_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.benchmark = True + torch.backends.cudnn.deterministic = False + + config.experiment.logging_dir = str(Path(config.experiment.output_dir) / "logs") + accelerator = Accelerator( + gradient_accumulation_steps=config.training.gradient_accumulation_steps, + mixed_precision=config.training.mixed_precision, + log_with="wandb", + project_dir=config.experiment.logging_dir, + split_batches=True, + ) + + total_batch_size_per_gpu = (config.training.batch_size_t2i + + config.training.batch_size_lm + + config.training.batch_size_mmu + + config.training.batch_size_v2t + + config.training.batch_size_s2t + + config.training.batch_size_t2s) + total_batch_size = ( + (config.training.batch_size_t2i + + config.training.batch_size_lm + + config.training.batch_size_mmu + + config.training.batch_size_v2t + + config.training.batch_size_s2t + + config.training.batch_size_t2s) * accelerator.num_processes * config.training.gradient_accumulation_steps + ) + + if accelerator.distributed_type == DistributedType.DEEPSPEED: + accelerator.state.deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = ( + total_batch_size_per_gpu + ) + + ##################################### + # SETUP LOGGING, SEED and CONFIG # + ##################################### + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + set_verbosity_info() + else: + set_verbosity_error() + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + resume_wandb_run = config.wandb.resume + run_id = config.wandb.get("run_id", None) + if run_id is None: + resume_wandb_run = False + run_id = wandb.util.generate_id() + config.wandb.run_id = run_id + + wandb_init_kwargs = dict( + name=config.experiment.name, + id=run_id, + resume=resume_wandb_run, + entity=config.wandb.get("entity", None), + config_exclude_keys=[], + dir = config.experiment.logging_dir, + ) + wandb_config = {k: v for k, v in flatten_omega_conf(config, resolve=True)} + wandb_config.pop("experiment.resume_from_checkpoint") + + accelerator.init_trackers( + config.experiment.project, + config=wandb_config, + init_kwargs={"wandb": wandb_init_kwargs}, + ) + + if accelerator.is_main_process: + os.makedirs(config.experiment.output_dir, exist_ok=True) + config_path = Path(config.experiment.output_dir) / "config.yaml" + logging.info(f"Saving config to {config_path}") + OmegaConf.save(config, config_path) + + # If passed along, set the training seed now. + if config.training.seed is not None: + set_seed(config.training.seed) + + ######################### + # MODELS and OPTIMIZER # + ######################### + logger.info("Loading models and optimizer") + + tokenizer = AutoTokenizer.from_pretrained(config.model.omada.tokenizer_path, padding_side="left") + + uni_prompting = UniversalPrompting(tokenizer, max_text_len=config.dataset.preprocessing.max_seq_length, max_audio_len=config.dataset.preprocessing.max_aud_length, + special_tokens=( + "<|soi|>", "<|eoi|>", "<|sov|>", "<|eov|>", "<|t2i|>", + "<|mmu|>", "<|t2v|>", "<|v2v|>", "<|lvg|>", + # Omada Special Tokens + "<|v2t|>", "<|s2t|>", "<|t2s|>", "<|soa|>", "<|eoa|>", + ), + ignore_id=-100, cond_dropout_prob=config.training.cond_dropout_prob, use_reserved_token=True) + + print('special tokens : \n', uni_prompting.sptids_dict) + + speech_vocab_start = len(uni_prompting.text_tokenizer) + int(config.model.omada.codebook_size) + audio_codebook_size = max(int(config.model.omada.new_vocab_size) - speech_vocab_start, 0) + t2s_special_token_ids = { + "eoa": int(uni_prompting.sptids_dict['<|eoa|>'][0].item()), + "eos": int(uni_prompting.text_tokenizer.eos_token_id), + } + + # VQ model for processing image into discrete tokens + vq_model_image = get_vq_model_class(config.model.vq_model_image.type) + if config.model.vq_model_image.get("pretrained_model_path", None): + vq_model_image = vq_model_image().to(accelerator.device) + state_dict = torch.load(config.model.vq_model_image.pretrained_model_path)['model'] + vq_model_image.load_state_dict(state_dict) + else: + vq_model_image = vq_model_image.from_pretrained(config.model.vq_model_image.vq_model_name).to(accelerator.device) + + vq_model_audio = get_vq_model_class(config.model.vq_model_audio.type) + vq_model_audio = vq_model_audio.from_pretrained(config.model.vq_model_audio.vq_model_name).to(accelerator.device) + + vq_model_image.eval() + vq_model_image.requires_grad_(False) + + vq_model_audio.eval() + vq_model_audio.requires_grad_(False) + + model = OMadaModelLM.from_pretrained(config.model.omada.pretrained_model_path, torch_dtype=torch.bfloat16).to(accelerator.device) + + # Resize Vocab size for Audio Modality + unwrapped_model = accelerator.unwrap_model(model) + original_vocab_size = unwrapped_model.get_input_embeddings().weight.shape[0] + logger.info("="*50) + logger.info(f"Calling resize_vocab...") + logger.info(f"Vocab size BEFORE resizing: {original_vocab_size}") + + resize_vocab(unwrapped_model, config) + + resized_vocab_size = unwrapped_model.get_input_embeddings().weight.shape[0] + logger.info(f"Vocab size AFTER resizing: {resized_vocab_size}") + logger.info(f"Config 'new_vocab_size': {config.model.omada.new_vocab_size}") + + if resized_vocab_size == config.model.omada.new_vocab_size: + logger.info("āœ… Vocab resize successful!") + else: + logger.info("āŒ Vocab resize FAILED or did not match config!") + logger.info("="*50) + mask_id = model.config.mask_token_id + + ################################## + # Optimizer and LR scheduler # + ################################# + optimizer_config = config.optimizer.params + + # no decay on bias and layernorm and embedding + no_decay = ["bias", "layer_norm.weight", "mlm_ln.weight", "embeddings.weight"] + optimizer_grouped_parameters = [ + { + "params": [p for n, p in model.named_parameters() if + p.requires_grad and not any(nd in n for nd in no_decay)], + "weight_decay": optimizer_config.weight_decay, + }, + { + "params": [p for n, p in model.named_parameters() if + p.requires_grad and any(nd in n for nd in no_decay)], + "weight_decay": 0.0, + }, + ] + + optimizer_type = config.optimizer.name + if optimizer_type == "adamw": + optimizer = AdamW( + optimizer_grouped_parameters, + lr=optimizer_config.learning_rate, + betas=(optimizer_config.beta1, optimizer_config.beta2), + weight_decay=optimizer_config.weight_decay, + eps=optimizer_config.epsilon, + ) + else: + raise ValueError(f"Optimizer {optimizer_type} not supported") + + # Create mask scheduler + if config.get("mask_schedule", None) is not None: + schedule = config.mask_schedule.schedule + args = config.mask_schedule.get("params", {}) + mask_schedule = get_mask_schedule(schedule, **args) + else: + mask_schedule = get_mask_schedule(config.training.get("mask_schedule", "cosine")) + + ################################## + # DATALOADER # + ################################# + logger.info("Creating dataloaders and lr_scheduler") + + total_batch_size = ( + (config.training.batch_size_t2s + config.training.batch_size_s2t +config.training.batch_size_v2t) * accelerator.num_processes * config.training.gradient_accumulation_steps + ) + preproc_config = config.dataset.preprocessing + dataset_config = config.dataset.params + + # Video Dataset + video_captioning_dataset = VideoCaptionDataset( + transform=image_transform, + tokenizer=uni_prompting.text_tokenizer, + max_seq_length=preproc_config.max_seq_length, + resolution=preproc_config.resolution, + sample_method="uniform", + dataset_name = 'llavavid', + num_frames=8, + ) + + sampler_v2t = DistributedSampler( + video_captioning_dataset, + shuffle=True, # Should be true for training + drop_last=True + ) + + train_dataloader_v2t = DataLoader( + video_captioning_dataset, + batch_size=config.training.batch_size_v2t, + num_workers=dataset_config.num_workers, + collate_fn=collate_fn_video_caption, + sampler = sampler_v2t, + drop_last=True, + ) + + # Speech Dataset + dataset_sm = MixedSpeechTextDataset(config.dataset.params.audio_data) + + logger.info(f"Dataset Prepared.") + + # Use distinct DistributedSamplers for each speech dataloader to avoid iterator interference + if accelerator.num_processes > 1: + sampler_s2t = DistributedSampler( + dataset_sm, + num_replicas=accelerator.num_processes, + rank=accelerator.process_index, + shuffle=True, + drop_last=True, + ) + sampler_t2s = DistributedSampler( + dataset_sm, + num_replicas=accelerator.num_processes, + rank=accelerator.process_index, + shuffle=True, + drop_last=True, + ) + else: + sampler_s2t = None + sampler_t2s = None + + train_dataloader_s2t = DataLoader( + dataset_sm, + batch_size=config.training.batch_size_s2t, + shuffle=False, + sampler=sampler_s2t, + collate_fn=collate_fn_audio, + num_workers=config.dataset.params.num_workers, + drop_last=True, + ) + train_dataloader_t2s = DataLoader( + dataset_sm, + batch_size=config.training.batch_size_t2s, + shuffle=False, + sampler=sampler_t2s, + collate_fn=collate_fn_audio, + num_workers=config.dataset.params.num_workers, + drop_last=True, + ) + + # Combine these dataloaders into a single iterable model + iterables = { + "v2t_flow": train_dataloader_v2t, + "t2s_flow": train_dataloader_t2s, + "s2t_flow": train_dataloader_s2t, + } + + combined_dataloader = CombinedLoader(iterables, mode=config.dataset.combined_loader_mode) + + # s2t + total_batch_size_s2t = config.training.batch_size_s2t * accelerator.num_processes * config.training.gradient_accumulation_steps + num_update_steps_per_epoch_s2t = math.ceil(len(dataset_sm) / total_batch_size_s2t) + + # t2s + total_batch_size_t2s = config.training.batch_size_t2s * accelerator.num_processes * config.training.gradient_accumulation_steps + num_update_steps_per_epoch_t2s = math.ceil(len(dataset_sm) / total_batch_size_t2s) + + # v2t + total_batch_size_v2t = (config.training.batch_size_v2t * accelerator.num_processes * config.training.gradient_accumulation_steps) + num_update_steps_per_epoch_v2t = math.ceil(len(video_captioning_dataset) / total_batch_size_v2t) + + + # Calculate num_train_epochs + num_update_steps_per_epoch = max(num_update_steps_per_epoch_s2t, num_update_steps_per_epoch_t2s, num_update_steps_per_epoch_v2t) + num_train_epochs = math.ceil(config.training.max_train_steps / num_update_steps_per_epoch) if num_update_steps_per_epoch > 0 else 1 + + logger.info(f"len of speech: {len(dataset_sm)}") + logger.info(f"len of video: {len(video_captioning_dataset)}") + logger.info(f"Train stpes: {config.training.max_train_steps}") + logger.info(f"Num train epochs: {num_train_epochs}") + + ################################## + # MODEL RESUME # + ################################# + global_step = 0 + first_epoch = 0 + start_step = 0 + + if config.experiment.resume_from_checkpoint: + dirs = os.listdir(config.experiment.output_dir) + logger.info(f"dirs: {dirs}") + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + logger.info(f"path: {path}") + if path is not None: + path = os.path.join(config.experiment.output_dir, path) + logger.info(f"Resuming from checkpoint: {path}") + global_step = start_step = int(os.path.basename(path).split("-")[1]) + first_epoch = global_step // num_update_steps_per_epoch + if os.path.exists(f'{path}/unwrapped_model/pytorch_model.bin'): + state_dict = torch.load(f'{path}/unwrapped_model/pytorch_model.bin', map_location="cpu") + model.load_state_dict(state_dict, strict=True) + del state_dict + elif os.path.exists(f'{path}/unwrapped_model/pytorch_model.bin.index.json'): + from safetensors.torch import load_file + from transformers.modeling_utils import load_sharded_checkpoint + load_sharded_checkpoint(model, f'{path}/unwrapped_model/') + # if safetensors sharded checkpoint exists + elif os.path.exists(f'{path}/unwrapped_model/model.safetensors.index.json'): + from transformers.modeling_utils import load_sharded_checkpoint + load_sharded_checkpoint( + model, + f'{path}/unwrapped_model/', + ) + else: + raise FileNotFoundError(f"Checkpoint {path}/unwrapped_model/pytorch_model.bin or safetensors not found") + else: + logger.info("Not resuming from checkpoint") + + ################################## + # Prepare accelerator # + ################################# + logger.info("Preparing model, optimizer and dataloaders") + + lr_scheduler = get_scheduler( + config.lr_scheduler.scheduler, + optimizer=optimizer, + num_training_steps=config.training.max_train_steps, + num_warmup_steps=config.lr_scheduler.params.warmup_steps, + min_lr_scale=config.lr_scheduler.params.min_lr_scale + ) + + # model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler) + model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler) + + lr_scheduler = get_scheduler( + config.lr_scheduler.scheduler, + optimizer=optimizer, + num_training_steps=config.training.max_train_steps, + num_warmup_steps=config.lr_scheduler.params.warmup_steps, + min_lr_scale=config.lr_scheduler.params.min_lr_scale + ) + + vq_model_image.to(device=accelerator.device) + vq_model_audio.to(device=accelerator.device) + + mask_dtype = model.get_input_embeddings().weight.dtype + + def _log_and_flag_failure(message: str, exc: Exception = None): + """Log preprocessing failures on both logger and accelerator console.""" + if exc is not None: + logger.exception(message) + else: + logger.error(message) + accelerator.print(message) + + def safe_audio_encode(audio_path: str, flow_name: str): + try: + tokens = vq_model_audio.encode(audio_path) + return tokens, None + except Exception as exc: + msg = ( + f"[Rank {accelerator.process_index}] {flow_name} audio encode failed " + f"for '{audio_path}': {exc}" + ) + _log_and_flag_failure(msg, exc) + return None, msg + + def safe_video_get_code(video_tensor_sample: torch.Tensor, sample_index: int): + try: + video_token = vq_model_image.get_code(video_tensor_sample) + return video_token, None + except Exception as exc: + msg = ( + f"[Rank {accelerator.process_index}] v2t video encode failed " + f"for sample index {sample_index}: {exc}" + ) + _log_and_flag_failure(msg, exc) + return None, msg + + ################################## + # Training # + ################################# + logger.info("***** Running training *****") + logger.info(f" Num training steps = {config.training.max_train_steps}") + logger.info(f" Instantaneous batch size per device = {total_batch_size_per_gpu}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {config.training.gradient_accumulation_steps}") + + + @torch.no_grad() + def prepare_inputs_and_labels( + pixel_values_or_image_ids: Union[torch.FloatTensor, torch.LongTensor], + texts: Union[str, str], + min_masking_rate: float = 0.0, + is_train: bool = True, + seed: int = None + ): + + image_tokens = vq_model_image.get_code(pixel_values_or_image_ids) + image_tokens = image_tokens + len(uni_prompting.text_tokenizer) + # create MLM mask and labels + input_ids, labels, loss_weight, mask_prob = mask_or_random_replace_tokens( + image_tokens, + mask_id, + config, + mask_schedule=mask_schedule, + is_train=is_train, + ) + input_ids, masks, labels = uni_prompting((texts, input_ids, labels), 't2i') + return input_ids, labels, mask_prob, image_tokens, masks + + @torch.no_grad() + def prepare_inputs_and_labels_for_text( + texts: Union[str, str], max_seq_len, eps=1e-3 + ): + # create MLM mask and labels + + input_ids_lm, prompt_mask, labels_lm = uni_prompting((texts_lm, max_seq_len), 'lm') + b, l = input_ids_lm.shape + t = torch.rand(b, device=input_ids_lm.device) + p_mask = (1 - eps) * t + eps + p_mask = p_mask[:, None].repeat(1, l) + + masked_indices = torch.rand((b, l), device=input_ids_lm.device) < p_mask + # 126336 is used for [MASK] token + noisy_batch = torch.where(masked_indices, mask_id, input_ids_lm) + masked_indices = noisy_batch == mask_id + + return noisy_batch, labels_lm, p_mask + + # Video also uses this. + @torch.no_grad() + def prepare_inputs_and_labels_for_mmu( + input_ids_mmu, prompt_masks, labels_mmu, eps=1e-3 + ): + b, l = input_ids_mmu.shape + t = torch.rand(b, device=input_ids_mmu.device) + p_mask = (1 - eps) * t + eps + p_mask = p_mask[:, None].repeat(1, l) + + masked_indices = torch.rand((b, l), device=input_ids_mmu.device) < p_mask + # 126336 is used for [MASK] token + noisy_batch = torch.where(masked_indices, mask_id, input_ids_mmu) + masked_indices = noisy_batch == mask_id + noisy_batch[prompt_masks.bool()] = input_ids_mmu[prompt_masks.bool()] + masked_indices = noisy_batch == mask_id + + prompt_masks = prompt_masks.to(torch.int64) + answer_lengths = torch.sum((1 - prompt_masks), dim=-1, keepdim=True) + answer_lengths = answer_lengths.repeat(1, noisy_batch.shape[1]) + + return noisy_batch, labels_mmu, p_mask, answer_lengths + + @torch.no_grad() + def prepare_inputs_and_labels_for_t2s( + input_ids_t2s, prompt_masks, labels_t2s, eps=1e-3 + ): + b, l = input_ids_t2s.shape + t = torch.rand(b, device=input_ids_t2s.device) + p_mask = (1 - eps) * t + eps + p_mask = p_mask[:, None].repeat(1, l) + + masked_indices = torch.rand((b, l), device=input_ids_t2s.device) < p_mask + noisy_batch = torch.where(masked_indices, mask_id, input_ids_t2s) + masked_indices = noisy_batch == mask_id + + noisy_batch[prompt_masks.bool()] = input_ids_t2s[prompt_masks.bool()] + masked_indices = noisy_batch == mask_id + + prompt_masks = prompt_masks.to(torch.int64) + answer_lengths = torch.sum((1 - prompt_masks), dim=-1, keepdim=True) + answer_lengths = answer_lengths.repeat(1, noisy_batch.shape[1]) + + return noisy_batch, labels_t2s, p_mask, answer_lengths + + + @torch.no_grad() + def prepare_inputs_and_labels_for_s2t( + input_ids_mmu, prompt_masks, labels_mmu, eps=1e-3 + ): + b, l = input_ids_mmu.shape + t = torch.rand(b, device=input_ids_mmu.device) + p_mask = (1 - eps) * t + eps + p_mask = p_mask[:, None].repeat(1, l) + + masked_indices = torch.rand((b, l), device=input_ids_mmu.device) < p_mask + # 126336 is used for [MASK] token + noisy_batch = torch.where(masked_indices, mask_id, input_ids_mmu) + masked_indices = noisy_batch == mask_id + noisy_batch[prompt_masks.bool()] = input_ids_mmu[prompt_masks.bool()] + masked_indices = noisy_batch == mask_id + + prompt_masks = prompt_masks.to(torch.int64) + answer_lengths = torch.sum((1 - prompt_masks), dim=-1, keepdim=True) + answer_lengths = answer_lengths.repeat(1, noisy_batch.shape[1]) + + return noisy_batch, labels_mmu, p_mask, answer_lengths + + batch_time_m = AverageMeter() + data_time_m = AverageMeter() + end = time.time() + + for epoch in tqdm(range(first_epoch, num_train_epochs), desc="Epochs", disable=not accelerator.is_main_process, position=0): + # Ensure all samplers reshuffle in a rank-consistent way each epoch + try: + if isinstance(sampler_v2t, DistributedSampler): + sampler_v2t.set_epoch(epoch) + if accelerator.num_processes > 1: + if sampler_s2t is not None: + sampler_s2t.set_epoch(epoch) + if sampler_t2s is not None: + sampler_t2s.set_epoch(epoch) + except Exception: + pass + model.train() + for batch, batch_idx, dataloader_idx in combined_dataloader: + batch_size_t2i = 0 + batch_size_lm = 0 + batch_size_mmu = 0 + + # Synchronize skip decision across all ranks to avoid collective mismatches + local_skip = 1 if (batch is None or batch.get("v2t_flow") is None) else 0 + try: + skip_tensor = torch.tensor(local_skip, device=accelerator.device, dtype=torch.int32) + skip_sum = accelerator.reduce(skip_tensor, reduction='sum') + should_skip = skip_sum.item() > 0 + except Exception: + # Fallback if reduce isn't available for any reason + should_skip = local_skip == 1 + + if should_skip: + if accelerator.is_main_process and local_skip: + logger.warning(f"Skipping step {global_step} (batch is None or v2t_flow missing) [synced]") + continue + + batch_size_v2t = len(batch["v2t_flow"]["video"]) + batch_size_t2s = len(batch["t2s_flow"]["audio_path"]) + batch_size_s2t = len(batch["s2t_flow"]["audio_path"]) + + logger.info(f"batch_size_v2t: {batch_size_v2t}, batch_size_t2s: {batch_size_t2s}, batch_size_s2t: {batch_size_s2t}" ) + + # print(f"Rank {accelerator.process_index} loading data...") + # print(batch["s2t_flow"]["audio_path"]) + # print(batch["v2t_flow"]['captions']) + + audio_paths_s2t, texts_s2t = batch["s2t_flow"]["audio_path"], batch["s2t_flow"]["text"] + audio_paths_t2s, texts_t2s = batch["t2s_flow"]["audio_path"], batch["t2s_flow"]["text"] + offset = speech_vocab_start + video_tensor, texts_vid = batch["v2t_flow"]["video"], batch["v2t_flow"]["captions"] + + failure_messages = [] + step_failed = False + + # *-------*-------*-------*-------*-------*-------*-------*-------*-------*-------*-------* + # Build formatted sequences for video understanding + # *-------*-------*-------*-------*-------*-------*-------*-------*-------*-------*-------* + video_tensor = video_tensor.to(accelerator.device, non_blocking=True) + video_token_list = [] + prompt_v2t = ['<|start_header_id|>user<|end_header_id|>\n' + prompt + '<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n' for prompt in V2T_INSTRUCTION] + + ### Check if the dataset if instruction-tuning dataset + # In case of the video-instruction tuning dataset + is_vid_inst = False + vid_inst_prompt, vid_inst_answer = None, None + + # detect if this is video-instruction style (list of [prompt, answer]) + if isinstance(texts_vid[0], (list, tuple)) and isinstance(texts_vid[0][0], dict): + is_vid_inst = True + vid_inst_prompt = [] + vid_inst_answer = [] + for conv in texts_vid: + human = conv[0].get("value", "").replace("\n", "") + gpt = conv[1].get("value", "") + vid_inst_prompt.append(human) + vid_inst_answer.append(gpt) + + # original video forward pass + for vid_idx, video in enumerate(video_tensor): # each video is (T, C, H, W) + tokens, err = safe_video_get_code(video, vid_idx) + if err is not None: + failure_messages.append(err) + step_failed = True + break + video_token = tokens + len(uni_prompting.text_tokenizer) # add offset for video tokens + video_token = video_token.view(-1) # flatten to (T*D) + video_token_list.append(video_token) + + video_tokens = torch.stack(video_token_list, dim=0) # (B, T*D) + + if not step_failed: + if is_vid_inst: + # dataset has its own prompts and answers + texts_with_prompt = [ + f"<|start_header_id|>user<|end_header_id|>\n{vid_inst_prompt[i]}<|eot_id|>" + f"<|start_header_id|>assistant<|end_header_id|>\n{vid_inst_answer[i]}" + for i in range(len(vid_inst_answer)) + ] + else: + # generic video captioning dataset + prompt_v2t_selected = random.choice(V2T_INSTRUCTION) + texts_with_prompt = [ + f"<|start_header_id|>user<|end_header_id|>\n{prompt_v2t_selected}<|eot_id|>" + f"<|start_header_id|>assistant<|end_header_id|>\n{text}" + for text in texts_vid + ] + + print(texts_with_prompt) + input_ids_vid, prompt_masks_vid, labels_vid = uni_prompting((video_tokens, texts_with_prompt), 'v2t') + + input_ids_vid, labels_vid, p_mask_vid, answer_lengths_vid = prepare_inputs_and_labels_for_mmu( + input_ids_vid, prompt_masks_vid, labels_vid + ) + + input_ids_vid = input_ids_vid.to(accelerator.device, non_blocking=True) + + print("\n--- [DEBUG: V2T tensors check] ---") + + # input_ids / labels ģš”ģ•½ + print(f"input_ids_vid: shape={input_ids_vid.shape}, dtype={input_ids_vid.dtype}, device={input_ids_vid.device}") + print(f"labels_vid: shape={labels_vid.shape}, dtype={labels_vid.dtype}") + print(f"p_mask_vid: shape={p_mask_vid.shape if p_mask_vid is not None else None}") + print(f"answer_lengths_vid: {answer_lengths_vid}") + + # sanity check for NaN / inf + if torch.isnan(input_ids_vid).any() or torch.isinf(input_ids_vid).any(): + print("āš ļø input_ids_vid contains NaN or Inf values!") + if torch.isnan(labels_vid).any() or torch.isinf(labels_vid).any(): + print("āš ļø labels_vid contains NaN or Inf values!") + + # token ģ¼ė¶€ė§Œ ķ™•ģø (첫 ģƒ˜ķ”Œė§Œ) + print("\n[Example token ids] input_ids_vid[0, :30] =", input_ids_vid[0].tolist()) + print("[Example label ids] labels_vid[0, :30] =", labels_vid[0].tolist()) + + # prompt alignment check + print("\nPrompt + Answer text example:") + if isinstance(texts_with_prompt, list): + print("User prompt:", texts_with_prompt[0][:300]) # 첫 300ģžė§Œ 볓기 + + # *-------*-------*-------*-------*-------*-------*-------*-------*-------*-------*-------* + # Build formatted sequences for speech understanding + # *-------*-------*-------*-------*-------*-------*-------*-------*-------*-------*-------* + if not step_failed: + prompt_s2t = ['<|start_header_id|>user<|end_header_id|>\n' + prompt + '<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n' for prompt in S2T_INSTRUCTION] + + all_audio_tokens = [] + for path in audio_paths_s2t: + tokens, err = safe_audio_encode(path, "s2t") + if err is not None: + failure_messages.append(err) + step_failed = True + break + tokens = tokens.to(accelerator.device, non_blocking=True) + tokens_with_offset = tokens + offset + all_audio_tokens.append(tokens_with_offset) + + if not step_failed: + prompt = random.choice(prompt_s2t) + + texts_with_prompt = [f"{prompt}{text}" for text in texts_s2t] + + input_ids_s2t, prompt_masks_s2t, labels_s2t = uni_prompting((all_audio_tokens, texts_with_prompt), 's2t') + # Preserve trailing EOS tokens in s2t targets for explicit prediction. + input_ids_s2t, labels_s2t, p_mask_s2t, answer_lengths_s2t = prepare_inputs_and_labels_for_s2t(input_ids_s2t, prompt_masks_s2t, labels_s2t) + + # *-------*-------*-------*-------*-------*-------*-------*-------*-------*-------*-------* + # Build formatted sequences for speech generation + # *-------*-------*-------*-------*-------*-------*-------*-------*-------*-------*-------* + if not step_failed: + prompt_t2s = [prompt for prompt in T2S_INSTRUCTION] + + all_audio_tokens = [] + for path in audio_paths_t2s: + tokens, err = safe_audio_encode(path, "t2s") + if err is not None: + failure_messages.append(err) + step_failed = True + break + tokens = tokens.to(accelerator.device, non_blocking=True) + tokens_with_offset = tokens + offset + all_audio_tokens.append(tokens_with_offset) + + if not step_failed: + # Chat-style instruction formatting for T2S training + prompt = random.choice(prompt_t2s) + texts_with_prompt = [ + f"<|start_header_id|>user<|end_header_id|>\n{prompt}\n{text}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n" + for text in texts_t2s + ] + + # input_ids_t2s, prompt_masks_t2s, labels_t2s = uni_prompting((texts_with_prompt, all_audio_tokens), 't2s_ip') + input_ids_t2s, prompt_masks_t2s, labels_t2s = uni_prompting((texts_with_prompt, all_audio_tokens), 't2s') + input_ids_t2s, labels_t2s, p_mask_t2s, answer_lengths_t2s = prepare_inputs_and_labels_for_t2s(input_ids_t2s, prompt_masks_t2s, labels_t2s) + + failure_tensor = torch.tensor(1 if step_failed else 0, device=accelerator.device, dtype=torch.int32) + failure_sum = accelerator.reduce(failure_tensor, reduction='sum') + if failure_sum.item() > 0: + if accelerator.is_main_process and failure_messages: + for msg in failure_messages: + logger.warning(f"Skipping global step {global_step} due to preprocessing failure: {msg}") + batch_time_m.reset() + data_time_m.reset() + end = time.time() + continue + + # --------------------------------------------------------------------------------- + # 1. Define padding values + pad_token_id = uni_prompting.text_tokenizer.eos_token_id + + # 2. Find the maximum sequence length in the current batch + max_len = max( + input_ids_vid.shape[1], + input_ids_s2t.shape[1], + input_ids_t2s.shape[1] + ) + + # 3. Pad all tensors to the max_len + input_ids_vid = pad_tensor(input_ids_vid, max_len, pad_token_id) + input_ids_s2t = pad_tensor(input_ids_s2t, max_len, pad_token_id) + input_ids_t2s = pad_tensor(input_ids_t2s, max_len, pad_token_id) + labels_vid = pad_tensor(labels_vid, max_len, -100) + labels_s2t = pad_tensor(labels_s2t, max_len, -100) + labels_t2s = pad_tensor(labels_t2s, max_len, -100) + p_mask_vid = pad_tensor(p_mask_vid, max_len, 1.0) + p_mask_s2t = pad_tensor(p_mask_s2t, max_len, 1.0) + p_mask_t2s = pad_tensor(p_mask_t2s, max_len, 1.0) + answer_lengths_vid = pad_answer_lengths(answer_lengths_vid, max_len) + answer_lengths_s2t = pad_answer_lengths(answer_lengths_s2t, max_len) + answer_lengths_t2s = pad_answer_lengths(answer_lengths_t2s, max_len) + # --------------------------------------------------------------------------------- + + input_ids = torch.cat(( + input_ids_vid, + input_ids_s2t, + input_ids_t2s + ), dim=0) + labels = torch.cat(( + labels_vid, + labels_s2t, + labels_t2s + ), dim=0) + + # w/o texts and images + p_mask_lm = None + p_mask_mmu = None + answer_lengths_mmu = None + t2i_masks = None + + if global_step == 0 and epoch == 0: + logger.info("Input ids: {}".format(input_ids)) + logger.info("Input ids shape: {}".format(input_ids.shape)) + logger.info("Labels: {}".format(labels)) + + # with accelerator.accumulate(model): + logits, loss_t2i, loss_lm, loss_mmu, loss_vid, loss_s2t, loss_t2s, _, _ = accelerator.unwrap_model(model).forward_process( + # logits, loss_t2i, loss_lm, loss_mmu, loss_vid, loss_s2t, loss_t2s = model.forward_process( + input_ids=input_ids, + labels=labels, + batch_size_t2i=batch_size_t2i, + batch_size_lm=batch_size_lm, + batch_size_mmu=batch_size_mmu, + batch_size_v2t=batch_size_v2t, + batch_size_s2t=batch_size_s2t, + batch_size_t2s=batch_size_t2s, + max_seq_length=config.dataset.preprocessing.max_seq_length, + p_mask_lm=p_mask_lm, + p_mask_mmu=p_mask_mmu, + p_mask_vid=p_mask_vid, + p_mask_s2t=p_mask_s2t, + p_mask_t2s=p_mask_t2s, + answer_lengths_mmu=answer_lengths_mmu, + answer_lengths_vid=answer_lengths_vid, + answer_lengths_s2t=answer_lengths_s2t, + answer_lengths_t2s=answer_lengths_t2s, + t2i_masks=t2i_masks, + t2s_vocab_start=speech_vocab_start, + t2s_codebook_size=audio_codebook_size, + t2s_special_token_ids=t2s_special_token_ids, + ) + + # Gather the losses across all processes for logging (use reduce to avoid shape mismatches) + # avg_loss_t2i = accelerator.reduce(loss_t2i, reduction='mean') + # avg_loss_lm = accelerator.reduce(loss_lm, reduction='mean') + # avg_loss_mmu = accelerator.reduce(loss_mmu, reduction='mean') + + avg_loss_vid = accelerator.reduce(loss_vid, reduction='mean') + avg_loss_s2t = accelerator.reduce(loss_s2t, reduction='mean') + avg_loss_t2s = accelerator.reduce(loss_t2s, reduction='mean') + + # loss = (config.training.t2i_coeff * loss_t2i + + # config.training.lm_coeff * loss_lm + + # config.training.mmu_coeff * loss_mmu + + # config.training.vid_coeff * loss_vid + + # config.training.s2t_coeff * loss_s2t + + # config.training.t2s_coeff * loss_t2s) + + loss = (config.training.v2t_coeff * loss_vid + + config.training.s2t_coeff * loss_s2t + + config.training.t2s_coeff * loss_t2s) + + # HMM~~~~~ + avg_masking_rate = accelerator.reduce(p_mask_t2s.mean(), reduction='mean') + + accelerator.backward(loss) + + if config.training.max_grad_norm is not None and accelerator.sync_gradients: + accelerator.clip_grad_norm_(model.parameters(), config.training.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + + # log gradient norm before zeroing it + if ( + accelerator.sync_gradients + and (global_step + 1) % config.experiment.log_grad_norm_every == 0 + and accelerator.is_main_process + ): + log_grad_norm(model, accelerator, global_step + 1) + + optimizer.zero_grad(set_to_none=True) + + if accelerator.sync_gradients: + batch_time_m.update(time.time() - end) + end = time.time() + + # Log metrics + if (global_step + 1) % config.experiment.log_every == 0: + samples_per_second_per_gpu = ( + config.training.gradient_accumulation_steps * total_batch_size_per_gpu / batch_time_m.val + ) + logs = { + # "step_loss_t2i": avg_loss_t2i.item(), + # "step_loss_mmu": avg_loss_mmu.item(), + # "step_loss_lm": avg_loss_lm.item(), + "step_loss_vid": avg_loss_vid.item(), + "step_loss_s2t": avg_loss_s2t.item(), + "step_loss_t2s": avg_loss_t2s.item(), + "lr": lr_scheduler.get_last_lr()[0], + # "avg_masking_rate": avg_masking_rate.item(), + "samples/sec/gpu": samples_per_second_per_gpu, + "data_time": data_time_m.val, + "batch_time": batch_time_m.val, + } + accelerator.log(logs, step=global_step + 1) + + logger.info( + f"Step: {global_step + 1} " + # f"Loss_t2i: {avg_loss_t2i.item():0.4f} " + # f"Loss_mmu: {avg_loss_mmu.item():0.4f} " + # f"Loss_lm: {avg_loss_lm.item():0.4f} " + f"Loss_vid: {avg_loss_vid.item():0.4f} " + f"Loss_s2t: {avg_loss_s2t.item():0.4f} " + f"Loss_t2s: {avg_loss_t2s.item():0.4f} " + f"Data (t): {data_time_m.val:0.4f}, {samples_per_second_per_gpu:0.2f}/s/gpu " + f"Batch (t): {batch_time_m.val:0.4f} " + f"LR: {lr_scheduler.get_last_lr()[0]:0.6f}" + ) + + # resetting batch / data time meters per log window + batch_time_m.reset() + data_time_m.reset() + + # Save model checkpoint + if (global_step + 1) % config.experiment.save_every == 0: + save_checkpoint(model, config, accelerator, global_step + 1, uni_prompting) + + # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + # ++++++++++++++++++++++ RUN EVALUATION +++++++++++++++++++++++++ + # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + if global_step == 0 or (global_step + 1) % config.experiment.get("eval_every", 5000) == 0: + run_evaluation( + model=accelerator.unwrap_model(model), + vq_model_image=vq_model_image, + vq_model_audio=vq_model_audio, + uni_prompting=uni_prompting, + config=config, + accelerator=accelerator, + global_step=global_step + 1 + ) + # Evaluation function sets model back to train mode internally + # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + + global_step += 1 + + if global_step >= config.training.max_train_steps: + break + + if global_step >= config.training.max_train_steps: + break + + accelerator.wait_for_everyone() + + save_checkpoint(model, config, accelerator, global_step, uni_prompting) + + if accelerator.is_main_process: + model = accelerator.unwrap_model(model) + model.save_pretrained(config.experiment.output_dir, safe_serialization=True) + + accelerator.end_training() + +@torch.no_grad() +def visualize_predictions(*args, **kwargs): + # This function is not called in the main loop but kept for compatibility + pass + +@torch.no_grad() +def generate_images(*args, **kwargs): + # This function is not called in the main loop but kept for compatibility + pass + +@torch.no_grad() +def understanding_images(*args, **kwargs): + # This function is not called in the main loop but kept for compatibility + pass + +def save_checkpoint(model, config, accelerator, global_step, uni_prompting): + output_dir = config.experiment.output_dir + checkpoints_total_limit = config.experiment.get("checkpoints_total_limit", None) + + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if accelerator.is_main_process and checkpoints_total_limit is not None: + checkpoints = os.listdir(output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= checkpoints_total_limit: + num_to_remove = len(checkpoints) - checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = Path(output_dir) / f"checkpoint-{global_step}" + + # retrieve the model on all processes for deepspeed stage 3 to work then save on one process (we are not using stage 3 yet) + # XXX: could also make this conditional on deepspeed + state_dict = accelerator.get_state_dict(model) + if accelerator.is_main_process: + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.save_pretrained( + save_path / "unwrapped_model", + save_function=accelerator.save, + state_dict=state_dict, + safe_serialization=True + ) + json.dump({"global_step": global_step}, (save_path / "metadata.json").open("w+")) + logger.info(f"Saved state to {save_path}") + + # save tokenizer + uni_prompting.text_tokenizer.save_pretrained(save_path/ "unwrapped_model") + + +def log_grad_norm(model, accelerator, global_step): + for name, param in model.named_parameters(): + if param.grad is not None: + grads = param.grad.detach().data + grad_norm = (grads.norm(p=2) / grads.numel()).item() + accelerator.log({"grad_norm/" + name: grad_norm}, step=global_step) + + +if __name__ == "__main__": + main() diff --git a/MMaDA/training/train_mmada_v2t_tmp.py b/MMaDA/training/train_mmada_v2t_tmp.py new file mode 100644 index 0000000000000000000000000000000000000000..5ae450203efc82593f282ce13f3a530e964802f0 --- /dev/null +++ b/MMaDA/training/train_mmada_v2t_tmp.py @@ -0,0 +1,393 @@ +# Copyright 2025 AIDAS Lab +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +os.environ["TOKENIZERS_PARALLELISM"] = "true" +import json +import logging +import math +import shutil +import time +from pathlib import Path +from typing import Union, List + +import numpy as np +from PIL import Image +from omegaconf import OmegaConf +import wandb +import torch +from torch.optim import AdamW +from lightning.pytorch.utilities import CombinedLoader + +from transformers import AutoTokenizer, AutoConfig +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import DistributedType, set_seed + +# +++++ I2I-specific Imports +++++ +from datasets import load_dataset +from torch.utils.data import Dataset, DataLoader +from tqdm.auto import tqdm +# ++++++++++++++++++++++++++++++ + +from training.data import Text2ImageDataset +from training.utils import get_config, flatten_omega_conf, image_transform +from training.imagenet_dataset import ImageNetDataset +from parquet import RefinedWebDataset + +from models import MAGVITv2, get_mask_schedule, MMadaModelLM, MMadaConfig +from training.prompting_utils import UniversalPrompting +from models.lr_schedulers import get_scheduler +from models.logging import set_verbosity_info, set_verbosity_error + +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler + + +SYSTEM_PROMPT_LEN = 28 + +from training.utils import get_config, flatten_omega_conf, mask_or_random_replace_tokens, AverageMeter + +try: + import apex + + is_apex_available = True +except ImportError: + is_apex_available = False + +logger = get_logger(__name__, log_level="INFO") + + +def get_vq_model_class(model_type): + if model_type == "magvitv2": + return MAGVITv2 + elif model_type == "vq16": + return VQ_16 + else: + raise ValueError(f"model_type {model_type} not supported.") + +# +++++ V2T Dataset Definition +++++ +class VideoCaptionDataset(Dataset): + def __init__(self, data, tokenizer, image_transform, vq_model, uni_prompting, config, device='cuda', num_frames=8): + """ + data: [{"video_url": str, "caption": str}, ...] + """ + self.data = data + self.tokenizer = tokenizer + self.image_transform = image_transform + self.vq_model = vq_model + self.uni_prompting = uni_prompting + self.config = config + self.device = device + self.num_frames = num_frames + + def _load_video_from_url(self, url): + # download video from url + local_path = '/home/work/AIDAS/data/video' ## donwload path + with requests.get(url, stream=True) as r: + with open(local_path, 'wb') as f: + for chunk in r.iter_content(chunk_size=8192): + f.write(chunk) + + cap = cv2.VideoCapture(local_path) + frames = [] + while True: + ret, frame = cap.read() + if not ret: + break + frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + frames.append(Image.fromarray(frame)) + cap.release() + os.remove(local_path) + if len(frames) < self.num_frames: + raise ValueError(f"Too few frames: {len(frames)}") + indices = np.linspace(0, len(frames) - 1, self.num_frames).astype(int) + sampled_frames = [frames[i] for i in indices] + return sampled_frames + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + entry = self.data[idx] + video_url = entry['video_url'] + caption = entry['caption'] + + # 1. frame extract + frames = self._load_video_from_url(video_url) + frame_tensors = [self.image_transform(img, resolution=self.config.dataset.params.resolution).to(self.device) for img in frames] + frame_tensors = torch.stack(frame_tensors, dim=0) # (num_frames, 3, H, W) + # vq tokenize, concat + video_tokens = torch.cat([self.vq_model.get_code(ft.unsqueeze(0)) + len(self.uni_prompting.text_tokenizer) for ft in frame_tensors], dim=1) + + return { + 'video_tokens': video_tokens, # shape: (1, num_frames * seq_len) + 'caption': caption + } + + +def collate_fn(batch): + video_tokens = torch.cat([item['video_tokens'] for item in batch], dim=0) + captions = [item['caption'] for item in batch] + return {'video_tokens': video_tokens, 'captions': captions} +# ++++++++++++++++++++++++++++++++++ + +def main(): + ######################### + # SETUP Accelerator # + ######################### + config = get_config() + + + # Enable TF32 on Ampere GPUs + if config.training.enable_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.benchmark = True + torch.backends.cudnn.deterministic = False + + config.experiment.logging_dir = str(Path(config.experiment.output_dir) / "logs") + accelerator = Accelerator( + gradient_accumulation_steps=config.training.gradient_accumulation_steps, + mixed_precision=config.training.mixed_precision, + log_with="wandb", + project_dir=config.experiment.logging_dir, + ) + + total_batch_size_per_gpu = config.training.batch_size_i2i #### 바꿔야 할듯?? + total_batch_size = ( + config.training.batch_size_i2i + * accelerator.num_processes * config.training.gradient_accumulation_steps + ) + + if accelerator.distributed_type == DistributedType.DEEPSPEED: + accelerator.state.deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = ( + total_batch_size_per_gpu + ) + + ##################################### + # SETUP LOGGING, SEED and CONFIG # + ##################################### + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + set_verbosity_info() + else: + set_verbosity_error() + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + resume_wandb_run = config.wandb.resume + run_id = config.wandb.get("run_id", None) + if run_id is None: + resume_wandb_run = False + run_id = wandb.util.generate_id() + config.wandb.run_id = run_id + + wandb_init_kwargs = dict( + name=config.experiment.name, + id=run_id, + resume=resume_wandb_run, + entity=config.wandb.get("entity", None), + config_exclude_keys=[], + ) + wandb_config = {k: v for k, v in flatten_omega_conf(config, resolve=True)} + wandb_config.pop("experiment.resume_from_checkpoint") + + accelerator.init_trackers( + config.experiment.project, + config=wandb_config, + init_kwargs={"wandb": wandb_init_kwargs}, + ) + + if accelerator.is_main_process: + os.makedirs(config.experiment.output_dir, exist_ok=True) + config_path = Path(config.experiment.output_dir) / "config.yaml" + logging.info(f"Saving config to {config_path}") + OmegaConf.save(config, config_path) + + # If passed along, set the training seed now. + if config.training.seed is not None: + set_seed(config.training.seed) + + + + + ######################### + # MODELS and OPTIMIZER # + ######################### + logger.info("Loading models and optimizer") + logger.info("="*50) + logger.info(f"max_train_steps from config: {config.training.max_train_steps}") + logger.info("="*50) + + tokenizer = AutoTokenizer.from_pretrained(config.model.mmada.tokenizer_path, padding_side="left") + + uni_prompting = UniversalPrompting( + tokenizer, + max_text_len=config.dataset.preprocessing.max_seq_length, + special_tokens=( + "<|sov|>", "<|eov|>", "<|eot|>", "<|sot|>", "<|mmu|>", "<|v2t|>" ### ģž˜ ėŖØė„“ź² ģŒ + ), + ignore_id=-100, + cond_dropout_prob=config.training.cond_dropout_prob, + use_reserved_token=True + ) + + + # VQ model for processing image into discrete tokens + vq_model = get_vq_model_class(config.model.vq_model.type) + if config.model.vq_model.get("pretrained_model_path", None): + vq_model = vq_model().to(accelerator.device) + state_dict = torch.load(config.model.vq_model.pretrained_model_path)['model'] + vq_model.load_state_dict(state_dict) + else: + vq_model = vq_model.from_pretrained(config.model.vq_model.vq_model_name).to(accelerator.device) + + vq_model.eval() + vq_model.requires_grad_(False) + model = MMadaModelLM.from_pretrained(config.model.mmada.pretrained_model_path, trust_remote_code=True, torch_dtype=torch.bfloat16).to(device) + + ################################## + # Optimizer and LR scheduler # + ################################# + optimizer_config = config.optimizer.params + + # no decay on bias and layernorm and embedding + no_decay = ["bias", "layer_norm.weight", "mlm_ln.weight", "embeddings.weight"] + optimizer_grouped_parameters = [ + { + "params": [p for n, p in model.named_parameters() if + p.requires_grad and not any(nd in n for nd in no_decay)], + "weight_decay": optimizer_config.weight_decay, + }, + { + "params": [p for n, p in model.named_parameters() if + p.requires_grad and any(nd in n for nd in no_decay)], + "weight_decay": 0.0, + }, + ] + + optimizer_type = config.optimizer.name + if optimizer_type == "adamw": + optimizer = AdamW( + optimizer_grouped_parameters, + lr=optimizer_config.learning_rate, + betas=(optimizer_config.beta1, optimizer_config.beta2), + weight_decay=optimizer_config.weight_decay, + eps=optimizer_config.epsilon, + ) + else: + raise ValueError(f"Optimizer {optimizer_type} not supported") + + # Create mask scheduler + if config.get("mask_schedule", None) is not None: + schedule = config.mask_schedule.schedule + args = config.mask_schedule.get("params", {}) + mask_schedule = get_mask_schedule(schedule, **args) + else: + mask_schedule = get_mask_schedule(config.training.get("mask_schedule", "cosine")) + + # lr_warmup_steps_for_scheduler = config.lr_scheduler.params.warmup_steps + # max_train_steps_for_scheduler = config.training.max_train_steps + + lr_warmup_steps_for_scheduler = config.lr_scheduler.params.warmup_steps * accelerator.num_processes + max_train_steps_for_scheduler = config.training.max_train_steps * accelerator.num_processes + + lr_scheduler = get_scheduler( + config.lr_scheduler.scheduler, + optimizer=optimizer, + num_warmup_steps=lr_warmup_steps_for_scheduler, + num_training_steps=max_train_steps_for_scheduler, + min_lr_scale=config.lr_scheduler.params.min_lr_scale + ) + + + # ----- dataloader ----- + dataset = load_dataset("adkfjlakdjflk;ajdflk;ajflk;jal;fjalk;sdfjal;jf", split="train") + train_data = [] + for sample in dataset: + train_data.append({ + "video_url": sample["video"], + "caption": sample["caption"] + }) + + train_dataset = VideoCaptionDataset(train_data, tokenizer, image_transform, vq_model, uni_prompting, config, device=device) + train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=2, collate_fn=collate_fn) + + + + # ################################## + # # Prepare accelerator # + # ################################# + # logger.info("Preparing model, optimizer and dataloader") + # model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + # model, optimizer, train_dataloader, lr_scheduler + # ) + # lr_scheduler = get_scheduler( + # config.lr_scheduler.scheduler, + # optimizer=optimizer, + # num_training_steps=config.training.max_train_steps, + # num_warmup_steps=config.lr_scheduler.params.warmup_steps, + # min_lr_scale=config.lr_scheduler.params.min_lr_scale + # ) + + # vq_model.to(device=accelerator.device) + ## ģ“ģƒķ•œė°?? + + + model, optimizer, train_loader, lr_scheduler = accelerator.prepare(model, optimizer, train_loader, lr_scheduler) + + for epoch in range(config.training.epochs): + model.train() + for step, batch in enumerate(tqdm(train_loader)): + video_tokens = batch['video_tokens'].to(device) + captions = batch['captions'] + + target = tokenizer(captions, return_tensors='pt', padding=True, truncation=True, max_length=config.dataset.preprocessing.max_seq_length).to(device) + labels = target['input_ids'] + attention_mask = target['attention_mask'] + + batch_size = video_tokens.shape[0] + input_ids = torch.cat([ + torch.full((batch_size,1), uni_prompting.sptids_dict['<|soi|>'], dtype=torch.long, device=device), + video_tokens, + torch.full((batch_size,1), uni_prompting.sptids_dict['<|eoi|>'], dtype=torch.long, device=device), + torch.full((batch_size,1), uni_prompting.sptids_dict['<|sot|>'], dtype=torch.long, device=device), + ], dim=1) + + outputs = model.v2t_forward( + input_ids=input_ids, + attention_mask=attention_mask, + labels=labels + ) + loss = outputs.loss + accelerator.backward(loss) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + if step % 10 == 0: + wandb.log({'loss': loss.item()}, step=epoch * len(train_loader) + step) + + + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/MMaDA/training/train_omada_inst.py b/MMaDA/training/train_omada_inst.py new file mode 100644 index 0000000000000000000000000000000000000000..ec4bbc05c9d32c7ba3ca960fcc06e281431197e9 --- /dev/null +++ b/MMaDA/training/train_omada_inst.py @@ -0,0 +1,4311 @@ +# Copyright 2025 AIDAS Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import base64 +import binascii +import hashlib +import os +import sys +import warnings +import subprocess +import tempfile +os.environ["FFMPEG_LOG_LEVEL"] = "error" +warnings.filterwarnings("ignore") + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +os.environ["TOKENIZERS_PARALLELISM"] = "true" +import json +import logging +import math +import torch.nn.functional as F +import shutil +import time +import cv2 +import glob +import random +import contextlib +from tqdm import tqdm +from pathlib import Path +from typing import Optional, Union, Dict, Any, List, Iterator +from collections.abc import Sequence +import csv +import numpy as np +from PIL import Image +from io import BytesIO +from omegaconf import OmegaConf, DictConfig +import wandb +import torch +from torch.optim import AdamW +from lightning.pytorch.utilities import CombinedLoader +import torch.multiprocessing as mp + +try: + cv2.utils.logging.setLogLevel(cv2.utils.logging.LOG_LEVEL_ERROR) +except AttributeError: + warnings.filterwarnings("ignore", category=FutureWarning) + +from transformers import AutoTokenizer, AutoConfig +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import DistributedType, set_seed +# +++++ I2I-specific Imports +++++ +from datasets import load_dataset +from torch.utils.data import Dataset, DataLoader +from tqdm.auto import tqdm +# ++++++++++++++++++++++++++++++ + +# +++++ Omni-modal-specific Imports +++++ +from models.modeling_emova_speech_tokenizer import EMOVASpeechTokenizer +from datasets import load_dataset +from torch.utils.data import Dataset, DataLoader, DistributedSampler +from tqdm.auto import tqdm +from training.data import ( + SpeechTextDataset, + MixedSpeechTextDataset, + Speech2SpeechDataset, + TextImageInterleavedDataset, + load_video_mp4, + VideoCaptionDataset, + VideoSpeechDataset, + S2T_INSTRUCTION, + T2S_INSTRUCTION, + V2S_INSTRUCTION, + s2s_collate_fn, +) +# import librosa + +from training.data import ( + Text2ImageDataset, + HQEditX2IDataset, + CombinedX2IDataset, + HFInstructionTextDataset, + TextToImage2MDataset, + OpenImageI2IDataset, +) +from training.utils import get_config, flatten_omega_conf, image_transform +from training.imagenet_dataset import ImageNetDataset + +from models import MAGVITv2, get_mask_schedule, OMadaModelLM, OMadaConfig +from training.prompting_utils import UniversalPrompting +from models.lr_schedulers import get_scheduler +from models.logging import set_verbosity_info, set_verbosity_error + +from torch.utils.data import DataLoader, Dataset +from torch.utils.data.distributed import DistributedSampler + +# ++++++++ EVALUATION IMPORTS ++++++++ +import re +import editdistance +import soundfile as sf +from functools import partial +from transformers import pipeline +# ++++++++++++++++++++++++++++++++++++ + +SYSTEM_PROMPT_LEN = 28 + +cv2.setNumThreads(0) +torch.set_num_threads(1) +os.environ["OMP_NUM_THREADS"] = "1" +os.environ["MKL_NUM_THREADS"] = "1" +os.environ["OPENBLAS_NUM_THREADS"] = "1" +os.environ["NUMEXPR_NUM_THREADS"] = "1" +os.environ["NCCL_ASYNC_ERROR_HANDLING"]= "1" +os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL" +os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"]= "1" +os.environ["TORCH_NCCL_BLOCKING_WAIT"]= "1" + +from training.utils import get_config, flatten_omega_conf, mask_or_random_replace_tokens, AverageMeter + +try: + import apex + + is_apex_available = True +except ImportError: + is_apex_available = False + +logger = get_logger(__name__, log_level="INFO") + + +def _broadcast_choice(value: int, accelerator: Accelerator, src: int = 0) -> int: + """Synchronize an integer choice across all ranks without relying on NCCL shared memory.""" + if accelerator.num_processes == 1: + return value + choice_tensor = torch.tensor( + [value if accelerator.process_index == src else 0], + device=accelerator.device, + dtype=torch.int32, + ) + gathered = accelerator.gather(choice_tensor) + return int(gathered[src].item()) + + +def _patch_shared_memory_tracker() -> None: + """Stop Python's resource_tracker from eagerly unlinking shared memory segments.""" + try: + from multiprocessing import resource_tracker + except ImportError: + return + + if getattr(resource_tracker, "_omada_shared_memory_patch", False): + return + + orig_register = resource_tracker.register + orig_unregister = resource_tracker.unregister + cleanup_funcs = getattr(resource_tracker, "_CLEANUP_FUNCS", {}) + orig_cleanup = cleanup_funcs.get("shared_memory") + + def _safe_register(name: str, rtype: str) -> None: + if rtype == "shared_memory": + return + orig_register(name, rtype) + + def _safe_unregister(name: str, rtype: str) -> None: + if rtype == "shared_memory": + return + orig_unregister(name, rtype) + + def _noop(*_args, **_kwargs) -> None: + return + + resource_tracker.register = _safe_register # type: ignore[assignment] + resource_tracker.unregister = _safe_unregister # type: ignore[assignment] + if orig_cleanup is not None: + cleanup_funcs["shared_memory"] = _noop + resource_tracker._omada_shared_memory_patch = True # type: ignore[attr-defined] + + +def _configure_multiprocessing() -> None: + """Configure torch multiprocessing to avoid shared-memory exhaustion on multi-worker loads.""" + try: + _patch_shared_memory_tracker() + except Exception as exc: # pragma: no cover - best-effort patching + logger.warning("Failed to apply shared memory tracker patch: %s", exc) + + try: + mp.set_sharing_strategy("file_descriptor") + except RuntimeError as exc: + logger.warning("Failed to set multiprocessing sharing strategy to 'file_descriptor': %s", exc) + +def pad_tensor(tensor, length, value): + pad_size = length - tensor.shape[1] + if pad_size <= 0: + return tensor + # Pad on the right side of the sequence (last dimension) + return torch.nn.functional.pad(tensor, (0, pad_size), "constant", value) + +def pad_answer_lengths(ans: torch.Tensor, length: int) -> torch.Tensor: + b, l = ans.shape + if l >= length: + return ans + pad_block = ans[:, :1].expand(b, length - l) + return torch.cat([ans, pad_block], dim=1) + +def resize_vocab(model, config): + logger.info(f"Resizing token embeddings to {config.model.omada.new_vocab_size}") + model.resize_token_embeddings(config.model.omada.new_vocab_size) + +def get_vq_model_class(model_type): + if model_type == "magvitv2": + return MAGVITv2 + elif model_type == "emova": + return EMOVASpeechTokenizer.from_pretrained( + "Emova-ollm/emova_speech_tokenizer_hf" + ) + else: + raise ValueError(f"model_type {model_type} not supported.") + +def collate_fn_audio(batch): + # In this setup, the tokenizer handles batching of audio paths + return { + 'audio_path': [item['audio_path'] for item in batch], + 'text': [item['text'] for item in batch], + 'audio_tokens': [item.get('audio_tokens') for item in batch], + } + + +def _empty_audio_batch() -> dict[str, list[Any]]: + """Utility to create an empty speech batch placeholder.""" + return { + "audio_path": [], + "text": [], + "audio_tokens": [], + } + + +def collate_fn_mmu_mult(batch): + return { + 'images': [item['images'] for item in batch], + 'text': [item['text'] for item in batch], + } + + +def collate_fn_x2i(batch): + t2i_texts: list[str] = [] + t2i_images: list[torch.Tensor] = [] + + i2i_prompts: list[str] = [] + i2i_source_images: list[torch.Tensor] = [] + i2i_target_images: list[torch.Tensor] = [] + + ref_image: Optional[torch.Tensor] = None + + has_i2i_sample = False + + for sample in batch: + input_prompt = sample.get("input_prompt") + output_prompt = sample.get("output_prompt") + edit_prompt = sample.get("edit_prompt") + inverse_prompt = sample.get("inverse_prompt") + input_image = sample.get("input_image") + output_image = sample.get("output_image") + + if isinstance(input_image, torch.Tensor) and ref_image is None: + ref_image = input_image + if isinstance(output_image, torch.Tensor) and ref_image is None: + ref_image = output_image + + has_edit_pair = ( + isinstance(input_image, torch.Tensor) + and isinstance(output_image, torch.Tensor) + and ( + (edit_prompt and edit_prompt.strip()) + or (inverse_prompt and inverse_prompt.strip()) + ) + ) + + if has_edit_pair: + has_i2i_sample = True + edit_candidates: list[tuple[str, torch.Tensor, torch.Tensor]] = [] + if edit_prompt and edit_prompt.strip(): + edit_candidates.append((edit_prompt, input_image, output_image)) + if inverse_prompt and inverse_prompt.strip(): + edit_candidates.append((inverse_prompt, output_image, input_image)) + + if edit_candidates: + chosen_prompt, chosen_src, chosen_tgt = random.choice(edit_candidates) + i2i_prompts.append(chosen_prompt) + i2i_source_images.append(chosen_src) + i2i_target_images.append(chosen_tgt) + continue + else: + if input_prompt and isinstance(input_image, torch.Tensor): + t2i_texts.append(input_prompt) + t2i_images.append(input_image) + elif output_prompt and isinstance(output_image, torch.Tensor): + t2i_texts.append(output_prompt) + t2i_images.append(output_image) + + if has_i2i_sample: + # i2iź°€ ķ•˜ė‚˜ė¼ė„ ģžˆģœ¼ė©“ ģ“ė²ˆ ė°°ģ¹˜ėŠ” i2i ģ „ģš©ģœ¼ė”œ ģ‚¬ģš©ķ•˜ź³  t2iėŠ” 비움 + t2i_texts = [] + t2i_images = [] + + def stack_images(images: list[torch.Tensor]) -> torch.Tensor: + if images: + return torch.stack(images, dim=0) + if ref_image is not None: + c, h, w = ref_image.shape[-3:] + return torch.empty((0, c, h, w), dtype=ref_image.dtype) + return torch.empty((0, 3, 0, 0), dtype=torch.float32) + + return { + "t2i": { + "texts": t2i_texts, + "images": stack_images(t2i_images), + }, + "i2i": { + "prompts": i2i_prompts, + "source_images": stack_images(i2i_source_images), + "target_images": stack_images(i2i_target_images), + }, + } + + +def collate_fn_v2t(batch: list[dict[str, Any]]) -> Optional[dict[str, Any]]: + filtered = [sample for sample in batch if sample is not None] + if not filtered: + return None + video_tensors: list[torch.Tensor] = [] + captions: list[Any] = [] + for sample in filtered: + frames = sample.get("video") + if frames is None: + continue + frame_tensor = torch.stack(frames, dim=0) + video_tensors.append(frame_tensor) + captions.append(sample.get("caption")) + if not video_tensors: + return None + return { + "video": torch.stack(video_tensors, dim=0), + "captions": captions, + } + + +def collate_fn_v2s(batch: list[dict[str, Any]]) -> Optional[dict[str, Any]]: + filtered = [sample for sample in batch if sample is not None] + if not filtered: + return None + video_tensors: list[torch.Tensor] = [] + speech_entries: list[Any] = [] + for sample in filtered: + frames = sample.get("video") + speech_value = sample.get("speech") + if frames is None or speech_value is None: + continue + frame_tensor = torch.stack(frames, dim=0) + video_tensors.append(frame_tensor) + speech_entries.append(speech_value) + if not video_tensors: + return None + return { + "video": torch.stack(video_tensors, dim=0), + "speech": speech_entries, + } + +def collate_fn_video_multimodal(batch): + text_videos: list[torch.Tensor] = [] + text_captions: list = [] + speech_videos: list[torch.Tensor] = [] + speech_entries: list[Any] = [] + + for sample in batch: + if sample is None: + continue + text_sample = sample.get("text") + if isinstance(text_sample, dict) and text_sample.get("video") is not None: + frames = text_sample["video"] + frame_tensor = torch.stack(frames, dim=0) + text_videos.append(frame_tensor) + text_captions.append(text_sample["caption"]) + + speech_sample = sample.get("speech") + if isinstance(speech_sample, dict) and speech_sample.get("video") is not None: + frames = speech_sample["video"] + frame_tensor = torch.stack(frames, dim=0) + speech_videos.append(frame_tensor) + speech_entries.append(speech_sample["speech"]) + + output: Dict[str, Any] = {} + if text_videos: + output["text"] = { + "video": torch.stack(text_videos, dim=0), + "captions": text_captions, + } + if speech_videos: + output["speech"] = { + "video": torch.stack(speech_videos, dim=0), + "speech": speech_entries, + } + if not output: + return None + return output + + +def s2t_eval_collate_fn(batch, vq_model_audio, tokenizer, uni_prompting, config): + + audio_tokens_batch = [] + offset = len(uni_prompting.text_tokenizer) + int(config.model.omada.codebook_size) + for item in batch: + audio_entry = item['audio_path'] + if isinstance(audio_entry, torch.Tensor): + tokens = audio_entry.cpu() + else: + tokens = vq_model_audio.encode(audio_entry).cpu() + tokens_with_offset = tokens + offset + audio_tokens_batch.append(tokens_with_offset) + + sptids_dict = uni_prompting.sptids_dict + device = audio_tokens_batch[0].device + batched_input_ids = [] + + for audio_tokens in audio_tokens_batch: + task_tensor = sptids_dict['<|s2t|>'].to(device).unsqueeze(0) + soa_tensor = sptids_dict['<|soa|>'].to(device).unsqueeze(0) + eoa_tensor = sptids_dict['<|eoa|>'].to(device).unsqueeze(0) + audio_block = torch.cat([task_tensor, soa_tensor, audio_tokens, eoa_tensor], dim=1) + + prompt_text = random.choice(S2T_INSTRUCTION) + full_prompt_text = f'<|start_header_id|>user<|end_header_id|>\n{prompt_text}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n' + prompt_tensor = tokenizer(full_prompt_text, return_tensors="pt").input_ids.to(device) + + final_sequence = torch.cat([audio_block, prompt_tensor], dim=1) + batched_input_ids.append(final_sequence.squeeze(0)) + + max_len = max(seq.size(0) for seq in batched_input_ids) + pad_token_id = 126093 + + final_batch_input_ids = torch.full( + (len(batched_input_ids), max_len), + pad_token_id, + dtype=torch.long, + device=device + ) + + for i, seq in enumerate(batched_input_ids): + final_batch_input_ids[i, -len(seq):] = seq + + return { + "input_ids": final_batch_input_ids, + "gt_texts": [item['gt_text'] for item in batch], + "sample_ids": [item['sample_id'] for item in batch] + } + +################################################################################################ +# +++++++++++++++++++++++++++++++++++++ EVALUATION HELPERS +++++++++++++++++++++++++++++++++++++ +################################################################################################ + +def add_gumbel_noise(logits, temperature): + ''' + The Gumbel max is a method for sampling categorical distributions. + According to arXiv:2409.02908, for MDM, low-precision Gumbel Max improves perplexity score but reduces generation quality. + Thus, we use float64. + ''' + if temperature == 0: + return logits + logits = logits.to(torch.float64) + noise = torch.rand_like(logits, dtype=torch.float64) + gumbel_noise = (- torch.log(noise)) ** temperature + return logits.exp() / gumbel_noise + + +def get_num_transfer_tokens(mask_index, steps): + ''' + In the reverse process, the interval [0, 1] is uniformly discretized into steps intervals. + Furthermore, because LLaDA employs a linear noise schedule (as defined in Eq. (8)), + the expected number of tokens transitioned at each step should be consistent. + + This function is designed to precompute the number of tokens that need to be transitioned at each step. + ''' + mask_num = mask_index.sum(dim=1, keepdim=True) + + base = mask_num // steps + remainder = mask_num % steps + + num_transfer_tokens = torch.zeros(mask_num.size(0), steps, device=mask_index.device, dtype=torch.int64) + base + + for i in range(mask_num.size(0)): + num_transfer_tokens[i, :remainder[i]] += 1 + + return num_transfer_tokens + +@ torch.no_grad() +def generate(model, prompt, steps=128, gen_length=128, block_length=128, temperature=0., + cfg_scale=0., remasking='low_confidence', mask_id=126336, attention_mask=None): + ''' + Args: + model: Mask predictor. + prompt: A tensor of shape (B, L), where B is batch size. + steps: Sampling steps, less than or equal to gen_length. + gen_length: Generated answer length. + block_length: Block length, less than or equal to gen_length. If less than gen_length, it means using semi_autoregressive remasking. + temperature: Categorical distribution sampling temperature. + cfg_scale: Unsupervised classifier-free guidance scale. + remasking: Remasking strategy. 'low_confidence' or 'random'. + mask_id: The toke id of [MASK] is 126336. + ''' + if attention_mask is not None and 0.0 in attention_mask: + attention_bias = (attention_mask[:, :, None] & attention_mask[:, None, :]).bool().unsqueeze(1) + print(f"attention_bias: {attention_bias}") + else: + attention_bias = None + batch_size = prompt.shape[0] + x = torch.full((batch_size, prompt.shape[1] + gen_length), mask_id, dtype=torch.long).to(model.device) + x[:, :prompt.shape[1]] = prompt.clone() + + prompt_index = (x != mask_id) + + assert gen_length % block_length == 0 + num_blocks = gen_length // block_length + + assert steps % num_blocks == 0 + steps = steps // num_blocks + + for num_block in range(num_blocks): + block_mask_index = (x[:, prompt.shape[1] + num_block * block_length: prompt.shape[1] + (num_block + 1) * block_length:] == mask_id) + num_transfer_tokens = get_num_transfer_tokens(block_mask_index, steps) + for i in range(steps): + mask_index = (x == mask_id) + if cfg_scale > 0.: + un_x = x.clone() + un_x[prompt_index] = mask_id + x_ = torch.cat([x, un_x], dim=0) + logits = model(x_).logits + logits, un_logits = torch.chunk(logits, 2, dim=0) + logits = un_logits + (cfg_scale + 1) * (logits - un_logits) + else: + logits = model(x, attention_bias=attention_bias).logits + + logits_with_noise = add_gumbel_noise(logits, temperature=temperature) + x0 = torch.argmax(logits_with_noise, dim=-1) # b, l + + if remasking == 'low_confidence': + p = F.softmax(logits.to(torch.float64), dim=-1) + x0_p = torch.squeeze( + torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1) # b, l + elif remasking == 'random': + x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device) + else: + raise NotImplementedError(remasking) + + x0_p[:, prompt.shape[1] + (num_block + 1) * block_length:] = -np.inf + + x0 = torch.where(mask_index, x0, x) + confidence = torch.where(mask_index, x0_p, -np.inf) + # print(confidence.shape) + transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device) + for j in range(confidence.shape[0]): + _, select_index = torch.topk(confidence[j], k=num_transfer_tokens[j, i]) + transfer_index[j, select_index] = True + x[transfer_index] = x0[transfer_index] + + return x + +def normalize_text(text): + """A simple normalizer for WER calculation.""" + text = text.lower() + text = re.sub(r"[^\w\s']", "", text) + return text + +def calculate_wer(predictions, references): + """Calculates the Word Error Rate (WER) between predicted and ground truth texts.""" + predictions = [normalize_text(p) for p in predictions] + references = [normalize_text(r) for r in references] + + total_errors = 0 + total_words = 0 + for pred, ref in zip(predictions, references): + pred_words = pred.split() + ref_words = ref.split() + total_errors += editdistance.eval(pred_words, ref_words) + total_words += len(ref_words) + + wer = total_errors / total_words if total_words > 0 else 0.0 + return wer, total_errors, total_words + +class S2TEvalDataset(Dataset): + def __init__(self, hf_dataset, root_path): + self.hf_dataset = hf_dataset + self.root_path = root_path + + def __len__(self): + return len(self.hf_dataset) + + def __getitem__(self, idx): + example = self.hf_dataset[idx] + sample_id = example['id'] + speaker_id, chapter_id, _ = sample_id.split('-') + audio_path = os.path.join(self.root_path, speaker_id, chapter_id, f"{sample_id}.flac") + + return { + "audio_path": audio_path, + "gt_text": example["text"], + "sample_id": sample_id + } + +# --- T2S Evaluation Dataset --- +class T2SEvalDataset(Dataset): + def __init__(self, hf_dataset): + self.hf_dataset = hf_dataset + def __len__(self): + return len(self.hf_dataset) + def __getitem__(self, idx): + example = self.hf_dataset[idx] + return {"gt_text": example['text'], "sample_id": example['id']} + +def _resolve_mask_schedule(config): + schedule_cfg = getattr(config, "mask_schedule", None) + if isinstance(schedule_cfg, DictConfig): + schedule_name = getattr(schedule_cfg, "schedule", None) + params_cfg = getattr(schedule_cfg, "params", None) + elif isinstance(schedule_cfg, dict): + schedule_name = schedule_cfg.get("schedule") + params_cfg = schedule_cfg.get("params") + else: + schedule_name = None + params_cfg = None + if schedule_name is None: + schedule_name = config.training.get("mask_schedule", "cosine") + params = {} + if params_cfg is not None: + if isinstance(params_cfg, DictConfig): + params = OmegaConf.to_container(params_cfg, resolve=True) or {} + elif isinstance(params_cfg, dict): + params = dict(params_cfg) + else: + params = params_cfg + if not isinstance(params, dict): + params = {} + return get_mask_schedule(schedule_name, **params) +def _tensor_to_pil(image_tensor: torch.Tensor) -> Image.Image: + image = torch.clamp((image_tensor.detach().cpu().float() + 1.0) / 2.0, min=0.0, max=1.0) + array = (image.permute(1, 2, 0).numpy() * 255.0).astype(np.uint8) + return Image.fromarray(array) + +################################################################################################ +# +++++++++++++++++++++++++++++++++++++ T2I EVALUATION LOGIC +++++++++++++++++++++++++++++++++++++ +################################################################################################ + +@torch.no_grad() +def evaluate_t2i(model, vq_model_image, uni_prompting, config, accelerator, global_step): + if not accelerator.is_main_process: + return + logger.info("***** Running T2I Evaluation *****") + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.eval() + prompts_file = "/home/work/AIDAS/MMaDA/validation_prompts/quantative.txt" + if not prompts_file: + logger.warning("No validation prompts file configured. Skipping T2I evaluation.") + return + prompts_path = Path(prompts_file) + if not prompts_path.is_absolute(): + prompts_path = Path.cwd() / prompts_path + if not prompts_path.exists(): + repo_root = Path(__file__).resolve().parents[2] + alt_path = repo_root / prompts_file + if alt_path.exists(): + prompts_path = alt_path + try: + with open(prompts_path, "r", encoding="utf-8") as handle: + prompts = [line.strip() for line in handle if line.strip()] + except OSError as exc: + logger.warning(f"Failed to read validation prompts from '{prompts_path}': {exc}. Skipping T2I evaluation.") + return + if not prompts: + logger.warning("Validation prompts file is empty. Skipping T2I evaluation.") + return + max_samples = getattr(config.experiment, "eval_num_t2i_samples", 8) + if not isinstance(max_samples, int) or max_samples <= 0: + max_samples = 8 + prompts = prompts[:max_samples] + mask_schedule = _resolve_mask_schedule(config) + mask_token_id = unwrapped_model.config.mask_token_id + seq_len = getattr(getattr(config.model, "omada", None), "num_vq_tokens", None) + if seq_len is None: + seq_len = getattr(unwrapped_model.config, "num_vq_tokens", None) + if seq_len is None: + logger.warning("Unable to determine image token sequence length. Skipping T2I evaluation.") + return + seq_len = int(seq_len) + device = accelerator.device + image_tokens = torch.full((len(prompts), seq_len), mask_token_id, dtype=torch.long, device=device) + input_ids, attention_mask = uni_prompting((prompts, image_tokens), 't2i_gen') + if config.training.guidance_scale > 0: + uncond_input_ids, uncond_attention_mask = uni_prompting(([''] * len(prompts), image_tokens), 't2i_gen') + cfg_scale = config.training.guidance_scale + else: + uncond_input_ids, uncond_attention_mask = None, None + cfg_scale = 0.0 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + else: + weight_dtype = torch.float32 + use_autocast = accelerator.device.type == "cuda" and accelerator.mixed_precision != "no" + autocast_ctx = torch.autocast("cuda", dtype=weight_dtype) if use_autocast else contextlib.nullcontext() + with autocast_ctx: + gen_token_ids = unwrapped_model.t2i_generate( + input_ids=input_ids, + uncond_input_ids=uncond_input_ids, + attention_mask=attention_mask, + uncond_attention_mask=uncond_attention_mask, + guidance_scale=3.5, + temperature=config.training.get("generation_temperature", 1.0), + timesteps=15, + noise_schedule=mask_schedule, + noise_type=config.training.get("noise_type", "mask"), + predict_all_tokens=config.training.get("predict_all_tokens", False), + seq_len=seq_len, + uni_prompting=uni_prompting, + config=config, + ) + gen_token_ids = torch.clamp(gen_token_ids, min=0, max=unwrapped_model.config.codebook_size - 1) + images = vq_model_image.decode_code(gen_token_ids) + images = torch.clamp((images + 1.0) / 2.0, min=0.0, max=1.0) + images = images.permute(0, 2, 3, 1).cpu().numpy() * 255.0 + pil_images = [Image.fromarray(img.astype(np.uint8)) for img in images] + wandb_images = [wandb.Image(img, caption=prompt) for img, prompt in zip(pil_images, prompts)] + accelerator.log({"eval/t2i_samples": wandb_images}, step=global_step) + +################################################################################################ +# +++++++++++++++++++++++++++++++++++++ I2I EVALUATION LOGIC +++++++++++++++++++++++++++++++++++++ +################################################################################################ + +@torch.no_grad() +def evaluate_i2i(model, vq_model_image, uni_prompting, config, accelerator, global_step): + if not accelerator.is_main_process: + return + logger.info("***** Running I2I Evaluation *****") + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.eval() + dataset_cfg_raw = getattr(config.dataset, "params", {}) + resolution = 512 + + def _cfg_to_dict(cfg): + if cfg is None: + return None + if isinstance(cfg, dict): + return cfg + if isinstance(cfg, DictConfig): + return OmegaConf.to_container(cfg, resolve=True) + return cfg + + dataset_cfg = _cfg_to_dict(dataset_cfg_raw) or {} + + eval_datasets: list[Dataset] = [] + eval_source_names: list[str] = [] + + # HQ-Edit evaluation dataset (always attempt; mirrors training) + try: + hqedit_split = dataset_cfg.get("hqedit_split", "train") + hqedit_eval = HQEditX2IDataset(split=hqedit_split, resolution=resolution) + if len(hqedit_eval) > 0: + eval_datasets.append(hqedit_eval) + eval_source_names.append(f"HQ-Edit[{hqedit_split}]") + else: + logger.warning("HQ-Edit evaluation split '%s' is empty; skipping.", hqedit_split) + except Exception as exc: + logger.warning("Failed to build HQ-Edit evaluation dataset: %s", exc) + + # OpenImage evaluation dataset if configured + openimage_cfg = _cfg_to_dict(dataset_cfg.get("openimage_i2i")) + if openimage_cfg: + try: + openimage_eval = OpenImageI2IDataset( + resolution=resolution, + image_root=openimage_cfg.get("image_root"), + sft_jsonl=openimage_cfg.get("sft_jsonl"), + pref_jsonl=openimage_cfg.get("pref_jsonl"), + multi_turn_jsonl=openimage_cfg.get("multi_turn_jsonl"), + prefer_summarized_text=bool(openimage_cfg.get("prefer_summarized_text", True)), + pref_positive_only=bool(openimage_cfg.get("pref_positive_only", True)), + skip_missing=bool(openimage_cfg.get("skip_missing", True)), + max_samples_per_source=openimage_cfg.get("max_samples_per_source"), + max_total_samples=openimage_cfg.get("max_total_samples"), + seed=openimage_cfg.get("seed"), + ) + if len(openimage_eval) > 0: + eval_datasets.append(openimage_eval) + eval_source_names.append("OpenImage I2I") + else: + logger.warning("OpenImage I2I evaluation dataset is empty; skipping.") + except Exception as exc: + logger.warning("Failed to build OpenImage I2I eval dataset: %s", exc) + + if not eval_datasets: + logger.warning("No i2i evaluation dataset available. Skipping.") + return + + eval_dataset = ( + eval_datasets[0] if len(eval_datasets) == 1 else CombinedX2IDataset(eval_datasets) + ) + logger.info("Using I2I evaluation datasets: %s", ", ".join(eval_source_names)) + + max_samples = getattr(config.experiment, "eval_num_i2i_samples", 8) + + if not isinstance(max_samples, int) or max_samples <= 0: + max_samples = 8 + num_samples = min(max_samples, len(eval_dataset)) + if len(eval_dataset) <= num_samples: + sample_indices = list(range(len(eval_dataset))) + else: + sample_indices = random.sample(range(len(eval_dataset)), num_samples) + samples = [eval_dataset[i] for i in sample_indices] + prompts = [] + original_tensors = [] + target_tensors = [] + for sample in samples: + prompts.append(sample.get("edit_prompt") or sample.get("output_prompt") or "") + original_tensors.append(sample["input_image"]) + target_tensors.append(sample["output_image"]) + original_images = torch.stack(original_tensors, dim=0).to(accelerator.device) + original_tokens = vq_model_image.get_code(original_images) + len(uni_prompting.text_tokenizer) + seq_len = original_tokens.shape[-1] + mask_token_id = unwrapped_model.config.mask_token_id + placeholder = torch.full((num_samples, seq_len), mask_token_id, dtype=torch.long, device=accelerator.device) + input_ids, attention_mask = uni_prompting((prompts, original_tokens, placeholder), 'i2i_gen') + if config.training.guidance_scale > 0: + uncond_input_ids, uncond_attention_mask = uni_prompting( + ([''] * num_samples, original_tokens, placeholder), 'i2i_gen' + ) + cfg_scale = config.training.guidance_scale + else: + uncond_input_ids, uncond_attention_mask = None, None + cfg_scale = 0.0 + mask_schedule = _resolve_mask_schedule(config) + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + else: + weight_dtype = torch.float32 + use_autocast = accelerator.device.type == "cuda" and accelerator.mixed_precision != "no" + autocast_ctx = torch.autocast("cuda", dtype=weight_dtype) if use_autocast else contextlib.nullcontext() + with autocast_ctx: + gen_token_ids = unwrapped_model.i2i_generate( + input_ids=input_ids, + uncond_input_ids=uncond_input_ids, + attention_mask=attention_mask, + uncond_attention_mask=uncond_attention_mask, + guidance_scale=3.5, + temperature=config.training.get("generation_temperature", 1.0), + timesteps=15, + noise_schedule=mask_schedule, + noise_type=config.training.get("noise_type", "mask"), + seq_len=seq_len, + uni_prompting=uni_prompting, + config=config, + ) + gen_token_ids = torch.clamp(gen_token_ids, min=0, max=unwrapped_model.config.codebook_size - 1) + generated_images = vq_model_image.decode_code(gen_token_ids) + generated_images = torch.clamp((generated_images + 1.0) / 2.0, min=0.0, max=1.0) + gen_images_pil = [Image.fromarray((img.permute(1, 2, 0).cpu().numpy() * 255.0).astype(np.uint8)) for img in generated_images] + source_pil = [_tensor_to_pil(tensor) for tensor in original_tensors] + target_pil = [_tensor_to_pil(tensor) for tensor in target_tensors] + log_resolution = getattr(config.experiment, "eval_image_log_resolution", 512) + wandb_images = [] + for prompt, src, pred, tgt in zip(prompts, source_pil, gen_images_pil, target_pil): + composite = Image.new('RGB', (log_resolution * 3, log_resolution)) + src_resized = src.resize((log_resolution, log_resolution), Image.Resampling.LANCZOS) + pred_resized = pred.resize((log_resolution, log_resolution), Image.Resampling.LANCZOS) + tgt_resized = tgt.resize((log_resolution, log_resolution), Image.Resampling.LANCZOS) + composite.paste(src_resized, (0, 0)) + composite.paste(pred_resized, (log_resolution, 0)) + composite.paste(tgt_resized, (log_resolution * 2, 0)) + wandb_images.append(wandb.Image(composite, caption=f"Prompt: {prompt}")) + accelerator.log({"eval/i2i_samples": wandb_images}, step=global_step) + + +################################################################################################ +# +++++++++++++++++++++++++++++++++++++ S2S EVALUATION LOGIC +++++++++++++++++++++++++++++++++++++ +################################################################################################ +@torch.no_grad() +def evaluate_s2s(model, vq_model_audio, uni_prompting, config, accelerator, global_step): + if not accelerator.is_main_process: + return + + logger.info("***** Running S2S Evaluation *****") + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.eval() + + dataset_cfg = getattr(config.dataset, "params", {}) + if isinstance(dataset_cfg, DictConfig): + dataset_cfg = dataset_cfg + s2s_eval_dir = getattr(dataset_cfg, "s2s_eval_dir", "MMaDA/validation_prompts/s2s") + + base_path = Path(s2s_eval_dir) + if not base_path.is_absolute(): + base_path = Path.cwd() / base_path + if not base_path.exists(): + repo_root = Path(__file__).resolve().parents[2] + alt_path = repo_root / s2s_eval_dir + if alt_path.exists(): + base_path = alt_path + + if not base_path.exists(): + logger.warning(f"S2S evaluation directory '{s2s_eval_dir}' not found. Skipping S2S evaluation.") + return + + audio_exts = {".wav", ".flac", ".mp3", ".ogg", ".m4a"} + wav_files = sorted(p for p in base_path.iterdir() if p.is_file() and p.suffix.lower() in audio_exts) + if not wav_files: + logger.warning(f"No audio files found in '{base_path}'. Skipping S2S evaluation.") + return + + condition = getattr(dataset_cfg, "s2s_eval_condition", "gender-female_emotion-neutral_speed-normal_pitch-normal") + mask_token_id = unwrapped_model.config.mask_token_id + codebook_size = int(getattr(config.model.omada, "codebook_size", 8192)) + + speech_vocab_start = len(uni_prompting.text_tokenizer) + codebook_size + audio_codebook_size = 4096 + if audio_codebook_size <= 0: + logger.warning("Computed audio codebook size is non-positive. Skipping S2S evaluation.") + return + + a_tokens = uni_prompting.text_tokenizer("<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n", return_tensors="pt").input_ids + soa_len = uni_prompting.sptids_dict['<|soa|>'].numel() + eoa_len = uni_prompting.sptids_dict['<|eoa|>'].numel() + asst_header_len = a_tokens.shape[1] + max_audio_len = getattr(uni_prompting, "max_audio_len", 256) + max_generatable = 256 + + offset = len(uni_prompting.text_tokenizer) + codebook_size + device = accelerator.device + + output_root = Path(config.experiment.output_dir) / "eval_s2s" / f"step_{global_step}" + output_root.mkdir(parents=True, exist_ok=True) + + table = wandb.Table(columns=["Audio ID", "Source Audio", "Generated Audio", "Token Count"]) + + for audio_path in wav_files: + try: + user_tokens = vq_model_audio.encode(str(audio_path)) + except Exception as exc: + logger.error(f"Failed to encode '{audio_path}': {exc}") + continue + + if not isinstance(user_tokens, torch.Tensor): + user_tokens = torch.tensor(user_tokens) + if user_tokens.dim() == 1: + user_tokens = user_tokens.unsqueeze(0) + + user_tokens = user_tokens.to(device=device, dtype=torch.long) + if user_tokens.numel() == 0: + logger.warning(f"Encoded audio from '{audio_path}' produced no tokens. Skipping sample.") + continue + + # Use a fixed assistant placeholder length so generation is not limited by user input duration. + assistant_len = max_generatable + if assistant_len <= 0: + logger.warning(f"Assistant placeholder length for '{audio_path}' is non-positive. Skipping sample.") + continue + + user_shifted = user_tokens + offset + assistant_placeholder = torch.full( + (1, assistant_len), + mask_token_id, + dtype=torch.long, + device=device, + ) + + input_ids, attention_mask = uni_prompting( + ([user_shifted], [assistant_placeholder]), + 's2s_gen' + ) + + try: + generated_sequences = unwrapped_model.t2s_generate_mmu_like( + input_ids=input_ids, + max_new_tokens=256, + steps=256, + block_length=256, + temperature=config.training.get("s2s_generation_temperature", 1.0), + cfg_scale=config.training.get("s2s_guidance_scale", 3.0), + mask_token_id=mask_token_id, + attention_mask=attention_mask, + uni_prompting=uni_prompting, + codebook_size=codebook_size, + audio_codebook_size=audio_codebook_size, + ) + except Exception as exc: + logger.error(f"Generation failed for '{audio_path}': {exc}") + continue + + if not generated_sequences: + logger.warning(f"No tokens generated for '{audio_path}'. Skipping sample.") + continue + + gen_tokens = generated_sequences[0] + if isinstance(gen_tokens, torch.Tensor): + gen_tokens = gen_tokens.detach().cpu() + token_list = gen_tokens.tolist() + if not token_list: + logger.warning(f"Generated token list empty for '{audio_path}'. Skipping sample.") + continue + + speech_unit_str = "".join([f"<|speech_{int(token)}|>" for token in token_list]) + output_path = output_root / f"{audio_path.stem}_reply.wav" + + try: + vq_model_audio.decode(speech_unit_str, condition=condition, output_wav_file=str(output_path)) + except Exception as exc: + logger.error(f"Decoding failed for '{audio_path}': {exc}") + continue + + table.add_data( + audio_path.name, + wandb.Audio(str(audio_path), caption="source"), + wandb.Audio(str(output_path), caption="generated"), + len(token_list), + ) + + row_count = getattr(table, "num_rows", None) + if row_count is None: + table_data = getattr(table, "data", None) + row_count = len(table_data) if table_data is not None else 0 + + if row_count > 0: + accelerator.log({"eval/s2s_samples": table}, step=global_step) + else: + logger.warning("S2S evaluation produced no loggable samples.") + + +################################################################################################ +# +++++++++++++++++++++++++++++++++++++ TEXT EVALUATION LOGIC ++++++++++++++++++++++++++++++++++++ +################################################################################################ +@torch.no_grad() +def evaluate_text(model, uni_prompting, config, accelerator, global_step): + if not accelerator.is_main_process: + return + + logger.info("***** Running Text Evaluation *****") + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.eval() + + dataset_cfg = getattr(config.dataset, "params", {}) + prompts_file = getattr(dataset_cfg, "text_eval_prompts_file", "MMaDA/validation_prompts/math.txt") + + prompts_path = Path(prompts_file) + if not prompts_path.is_absolute(): + prompts_path = Path.cwd() / prompts_path + if not prompts_path.exists(): + repo_root = Path(__file__).resolve().parents[2] + alt_path = repo_root / prompts_file + if alt_path.exists(): + prompts_path = alt_path + + if not prompts_path.exists(): + logger.warning(f"Text evaluation prompts file '{prompts_file}' not found. Skipping text evaluation.") + return + + try: + with open(prompts_path, "r", encoding="utf-8") as handle: + raw_prompts = [line.strip() for line in handle if line.strip()] + except OSError as exc: + logger.warning(f"Failed to read text evaluation prompts from '{prompts_path}': {exc}. Skipping text evaluation.") + return + + if not raw_prompts: + logger.warning("Text evaluation prompt list is empty. Skipping text evaluation.") + return + + max_samples = getattr(config.experiment, "eval_num_text_samples", 4) + if not isinstance(max_samples, int) or max_samples <= 0: + max_samples = 4 + questions = raw_prompts[:max_samples] + + chat_prompts = [ + f"<|start_header_id|>user<|end_header_id|>\n{question}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n" + for question in questions + ] + + tokenizer = uni_prompting.text_tokenizer + tokenizer.padding_side = "left" + if tokenizer.pad_token_id is None: + tokenizer.pad_token_id = tokenizer.eos_token_id + + answers: list[str] = [] + + for chat_prompt in chat_prompts: + tokens = tokenizer( + chat_prompt, + return_tensors="pt", + padding=True, + truncation=True, + ) + + input_ids = tokens["input_ids"].to(accelerator.device) + out = generate(unwrapped_model, input_ids, steps=128, gen_length=128, block_length=128, temperature=1, cfg_scale=0., remasking='low_confidence') + answer = tokenizer.batch_decode(out[:, input_ids.shape[1]:], skip_special_tokens=True) + + answers.append(answer) + + table = wandb.Table(columns=["Index", "Question", "Answer"]) + for idx, (question, answer) in enumerate(zip(questions, answers)): + table.add_data(idx, question, answer) + + accelerator.log({"eval/text_samples": table}, step=global_step) + +################################################################################################ +# +++++++++++++++++++++++++++++++++++++ S2T EVALUATION LOGIC +++++++++++++++++++++++++++++++++++++ +################################################################################################ +@torch.no_grad() +def evaluate_s2t(model, vq_model_audio, uni_prompting, config, accelerator, global_step): + if not accelerator.is_main_process: + return + logger.info("***** Running S2T Evaluation (WER on Librispeech test-clean) *****") + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.eval() + + # 1. Load Dataset + try: + s2t_eval_dataset_raw = load_dataset("librispeech_asr", "clean", split="test", streaming=False).select(range(32)) + s2t_eval_dataset = S2TEvalDataset(s2t_eval_dataset_raw, root_path = "/home/work/AIDAS/data/audio/LibriSpeech/test-clean") + except Exception as e: + logger.error(f"Failed to load S2T evaluation dataset: {e}") + return + + collate_with_args = partial( + s2t_eval_collate_fn, + vq_model_audio=vq_model_audio, + tokenizer=uni_prompting.text_tokenizer, + uni_prompting=uni_prompting, + config=config + ) + + s2t_eval_dataloader = DataLoader(s2t_eval_dataset, batch_size=config.training.batch_size_s2t, shuffle=False, collate_fn=collate_with_args) + + local_results = [] + + for batch in tqdm(s2t_eval_dataloader, desc="S2T Evaluation"): + input_ids = batch["input_ids"] + gt_texts = batch["gt_texts"] + sample_ids = batch["sample_ids"] + + output_ids = unwrapped_model.mmu_generate(input_ids, max_new_tokens=256, steps=256, block_length=128, remasking='low_confidence') + + decoded_texts = uni_prompting.text_tokenizer.batch_decode(output_ids[:, input_ids.shape[1]:], skip_special_tokens=True) + + eos_token = uni_prompting.text_tokenizer.eos_token + eos_marker = eos_token if eos_token is not None else "" + for i in range(len(decoded_texts)): + full_text = decoded_texts[i] + eos_idx = full_text.find(eos_marker) + cleaned_text = full_text[:eos_idx] if eos_idx != -1 else full_text + cleaned_text = cleaned_text.replace(eos_marker, "").strip() + local_results.append({ + "sample_id": sample_ids[i], + "gt_text": gt_texts[i], + "decoded_text": cleaned_text, + }) + + if not local_results: + logger.warning("S2T evaluation produced no results.") + return + + gt_list = [res["gt_text"] for res in local_results] + pred_list = [res["decoded_text"] for res in local_results] + + wer, errors, words = calculate_wer(pred_list, gt_list) + logger.info(f"S2T Final WER (Librispeech test-clean): {wer:.4f} | Word Errors: {errors} | Total Words: {words}") + + accelerator.log({ + "eval/s2t_wer": wer, + "eval/s2t_word_errors": errors, + "eval/s2t_total_words": words + }, step=global_step) + + samples_table = wandb.Table(columns=["ID", "Ground Truth", "Prediction"]) + for idx, res in enumerate(local_results): + sample_id = res.get("sample_id", idx) + samples_table.add_data(sample_id, res["gt_text"], res["decoded_text"]) + + accelerator.log({"eval/s2t_samples": samples_table}, step=global_step) + +################################################################################################ +# +++++++++++++++++++++++++++++++++++++ T2S EVALUATION LOGIC +++++++++++++++++++++++++++++++++++++ +################################################################################################ +@torch.no_grad() +def evaluate_t2s(model, vq_model_audio, uni_prompting, config, accelerator, global_step): + if not accelerator.is_main_process: + return + logger.info("***** Running T2S Evaluation (WER via Whisper on Librispeech) *****") + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.eval() + + # 1. Load Dataset & Whisper Model + try: + t2s_eval_dataset_raw = load_dataset("librispeech_asr", "clean", split="test").select(range(8)) + whisper_pipe = pipeline("automatic-speech-recognition", model="openai/whisper-large-v3", device=accelerator.device) + os.makedirs(f"{config.experiment.output_dir}/eval_audio", exist_ok=True) + except Exception as e: + logger.error(f"Failed to load T2S dataset or Whisper model: {e}") + return + + output_dir_per_step = os.path.join("/home/work/AIDAS", config.experiment.output_dir, "eval_audio", f"step_{global_step}") + os.makedirs(output_dir_per_step, exist_ok=True) + + t2s_eval_dataset = T2SEvalDataset(t2s_eval_dataset_raw) + t2s_dataloader = DataLoader(t2s_eval_dataset, batch_size=config.training.batch_size_t2s) + + local_results = [] + mask_token_id = unwrapped_model.config.mask_token_id + mask_schedule = get_mask_schedule(config.training.get("mask_schedule", "cosine")) + + # 2. Evaluation Loop + for batch in tqdm(t2s_dataloader, desc="T2S Evaluation"): + gt_texts = batch["gt_text"] + sample_ids = batch["sample_id"] + + # Chat-style instruction formatting for T2S: user prompt + text + prompts = [ + f"<|start_header_id|>user<|end_header_id|>\n{random.choice(T2S_INSTRUCTION)}\n{text}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n" + for text in gt_texts + ] + batch_size = len(prompts) + + # We need a reasonable length for generated audio tokens + speech_token_length = 384 - 1 # -1 for soa token + audio_tokens = torch.ones((batch_size, speech_token_length), dtype=torch.long, device=accelerator.device) * mask_token_id + input_ids, attention_mask = uni_prompting((prompts, audio_tokens), 't2s_gen') + + if config.training.guidance_scale > 0: + uncond_input_ids, uncond_attention_mask = uni_prompting(([''] * batch_size, audio_tokens), 't2s_gen') + else: + uncond_input_ids, uncond_attention_mask = None, None + + output_ids = unwrapped_model.t2s_generate( + input_ids=input_ids, + uncond_input_ids=uncond_input_ids, + attention_mask=attention_mask, + uncond_attention_mask=uncond_attention_mask, + guidance_scale=5.0, + temperature=1.0, + timesteps=50, + noise_schedule=mask_schedule, + noise_type="mask", + seq_len=383, + uni_prompting=uni_prompting, + config=config, + ) + + # Decode and run Whisper + for i in range(batch_size): + gt = gt_texts[i].rsplit("\n", 1)[-1].strip() + + gen_speech_tokens = output_ids[i] + + # Remove padding/eos if necessary, clamp to valid range + # gen_speech_tokens = torch.clamp(gen_speech_tokens, min=0, max= 4096 - 1) + id_list = gen_speech_tokens.cpu().tolist() + + if not id_list: + logger.warning(f"Generated token list is empty for sample {sample_ids[i]}. Skipping.") + continue + + speech_unit_str = " ".join(map(str, id_list)) + speech_unit_for_decode = "".join([f"<|speech_{unit}|>" for unit in speech_unit_str.split(" ")]) + + filename = f"process_{accelerator.process_index}_{sample_ids[i]}.wav" + output_wav_path = os.path.join(output_dir_per_step, filename) + condition = 'gender-female_emotion-neutral_speed-normal_pitch-normal' + + audio_array = vq_model_audio.decode(speech_unit_for_decode, condition=condition, output_wav_file=output_wav_path) + + whisper_result = whisper_pipe(output_wav_path, generate_kwargs={"language": "english"}) + whisper_text = whisper_result.get("text", "") + + local_results.append({ + "sample_id": sample_ids[i], "gt_text": gt, "whisper_text": whisper_text, "audio_path": output_wav_path + }) + + if not local_results: + logger.warning("Skipping T2S evaluation logging because no samples were generated.") + return + + gt_list = [res["gt_text"] for res in local_results] + pred_list = [res["whisper_text"] for res in local_results] + + wer, errors, words = calculate_wer(pred_list, gt_list) + logger.info(f"T2S Final WER (via Whisper): {wer:.4f} | Word Errors: {errors} | Total Words: {words}") + + accelerator.log({ + "eval/t2s_wer": wer, + "eval/t2s_word_errors": errors, + "eval/t2s_total_words": words + }, step=global_step) + + results_table = wandb.Table(columns=["ID", "Ground Truth", "Whisper Transcription", "Generated Audio"]) + for res in local_results[:8]: + audio = wandb.Audio(res["audio_path"], caption=res["whisper_text"]) + results_table.add_data(res["sample_id"], res["gt_text"], res["whisper_text"], audio) + + accelerator.log({"eval/t2s_samples": results_table}, step=global_step) + +@torch.no_grad() +def evaluate_t2s_mmu_like(model, vq_model_audio, uni_prompting, config, accelerator, global_step): + """Text-to-speech evaluation using the MMU-style block refinement decoder.""" + + if not accelerator.is_main_process: + return + + logger.info("***** Running T2S Evaluation (MMU-style decoder) *****") + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.eval() + + try: + t2s_eval_dataset_raw = load_dataset("librispeech_asr", "clean", split="test").select(range(8)) + whisper_pipe = pipeline("automatic-speech-recognition", model="openai/whisper-large-v3", device=accelerator.device) + os.makedirs(f"{config.experiment.output_dir}/eval_audio", exist_ok=True) + except Exception as exc: + logger.error(f"Failed to load T2S dataset or Whisper model for MMU-style eval: {exc}") + return + + output_dir_per_step = os.path.join("/home/work/AIDAS", config.experiment.output_dir, "eval_audio", f"step_{global_step}_mmu") + os.makedirs(output_dir_per_step, exist_ok=True) + + t2s_eval_dataset = T2SEvalDataset(t2s_eval_dataset_raw) + t2s_dataloader = DataLoader(t2s_eval_dataset, batch_size=config.training.batch_size_t2s) + + local_results = [] + mask_token_id = unwrapped_model.config.mask_token_id + + codebook_size = config.model.omada.codebook_size + speech_vocab_size = 4096 + + for batch in tqdm(t2s_dataloader, desc="T2S MMU Eval"): + gt_texts = batch["gt_text"] + sample_ids = batch["sample_id"] + + prompts = [ + f"<|start_header_id|>user<|end_header_id|>\n{random.choice(T2S_INSTRUCTION)}\n{text}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n" + for text in gt_texts + ] + + batch_size = len(prompts) + speech_token_length = 384 - 1 + audio_tokens = torch.ones((batch_size, speech_token_length), dtype=torch.long, device=accelerator.device) * mask_token_id + input_ids, attention_mask = uni_prompting((prompts, audio_tokens), 't2s_gen') + + output_ids = unwrapped_model.t2s_generate_mmu_like( + input_ids=input_ids, + max_new_tokens=speech_token_length, + steps=384 - 1, + block_length=384 - 1, + temperature=1.0, + cfg_scale=3.5, + mask_token_id=mask_token_id, + attention_mask=attention_mask, + uni_prompting=uni_prompting, + codebook_size=codebook_size, + audio_codebook_size=speech_vocab_size, + ) + + for i in range(batch_size): + gt = gt_texts[i].rsplit("\n", 1)[-1].strip() + + gen_speech_tokens = output_ids[i] + if isinstance(gen_speech_tokens, torch.Tensor): + gen_speech_tokens = gen_speech_tokens.detach().cpu() + + token_list = gen_speech_tokens.tolist() + if not token_list: + logger.warning(f"Generated token list is empty for sample {sample_ids[i]} (MMU eval). Skipping.") + continue + + speech_unit_str = " ".join(map(str, token_list)) + speech_unit_for_decode = "".join([f"<|speech_{unit}|>" for unit in speech_unit_str.split(" ")]) + + filename = f"process_{accelerator.process_index}_{sample_ids[i]}_mmu.wav" + output_wav_path = os.path.join(output_dir_per_step, filename) + condition = 'gender-female_emotion-neutral_speed-normal_pitch-normal' + + try: + vq_model_audio.decode(speech_unit_for_decode, condition=condition, output_wav_file=output_wav_path) + except Exception as exc: + logger.error(f"Decoding failed for sample {sample_ids[i]} (MMU eval): {exc}") + continue + + whisper_result = whisper_pipe(output_wav_path, generate_kwargs={"language": "english"}) + whisper_text = whisper_result.get("text", "") + + local_results.append({ + "sample_id": sample_ids[i], + "gt_text": gt, + "whisper_text": whisper_text, + "audio_path": output_wav_path, + }) + + if not local_results: + logger.warning("Skipping T2S MMU-style evaluation because no samples were generated.") + return + + gt_list = [res["gt_text"] for res in local_results] + pred_list = [res["whisper_text"] for res in local_results] + + wer, errors, words = calculate_wer(pred_list, gt_list) + logger.info(f"T2S (MMU-style) Final WER: {wer:.4f} | Word Errors: {errors} | Total Words: {words}") + + accelerator.log({ + "eval/t2s_mmu_like_wer": wer, + "eval/t2s_mmu_like_word_errors": errors, + "eval/t2s_mmu_like_total_words": words, + }, step=global_step) + + results_table = wandb.Table(columns=["ID", "Ground Truth", "Whisper Transcription", "Generated Audio"]) + for res in local_results[:8]: + audio = wandb.Audio(res["audio_path"], caption=res["whisper_text"]) + results_table.add_data(res["sample_id"], res["gt_text"], res["whisper_text"], audio) + + accelerator.log({"eval/t2s_mmu_like_samples": results_table}, step=global_step) + +################################################################################################ +# +++++++++++++++++++++++++++++++++++++ V2T EVALUATION LOGIC +++++++++++++++++++++++++++++++++++++ +################################################################################################ +@torch.no_grad() +def evaluate_v2t(model, vq_model_image, uni_prompting, config, accelerator, global_step): + # This is a qualitative evaluation, so it only runs on the main process. + if not accelerator.is_main_process: + return + + logger.info("***** Running V2T Qualitative Evaluation *****") + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.eval() + + video_root = "/home/work/AIDAS/video/demo" + if not video_root or not os.path.exists(video_root): + logger.warning(f"V2T eval root '{video_root}' not found. Skipping V2T evaluation.") + return + + file_list = [f for f in os.listdir(video_root) if f.lower().endswith('.mp4')] + if not file_list: + logger.warning(f"No .mp4 files found in '{video_root}'. Skipping V2T evaluation.") + return + + question = "Please provide a detailed description of the video." + results_table = wandb.Table(columns=["Video ID", "Question", "Generated Caption"]) + + for file_name in tqdm(file_list[:], desc="V2T Evaluation", disable=not accelerator.is_main_process): + video_path = os.path.join(video_root, file_name) + + # 1. Load and process video + cap = cv2.VideoCapture(video_path) + total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + indices = np.linspace(0, total_frames - 1, 8, dtype=int) + frames = [] + for i in range(total_frames): + ret, frame = cap.read() + if i in indices: + if not ret: continue + frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + pil_img = Image.fromarray(frame) + frames.append(image_transform(pil_img, resolution=config.dataset.preprocessing.resolution)) + cap.release() + + if len(frames) < 8: continue + + video_tensor = torch.stack(frames).to(accelerator.device) + video_tokens = vq_model_image.get_code(video_tensor) + len(uni_prompting.text_tokenizer) + video_tokens = video_tokens.view(1, -1) # Flatten tokens + + sptids = uni_prompting.sptids_dict + device = unwrapped_model.device + + prompt_text = f'<|start_header_id|>user<|end_header_id|>\n{question}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n' + prompt_tensor = uni_prompting.text_tokenizer(prompt_text, return_tensors="pt").input_ids.to(device) + + input_ids = torch.cat([ + sptids['<|v2t|>'].to(device).unsqueeze(0), + sptids['<|soi|>'].to(device).unsqueeze(0), + video_tokens, + sptids['<|eoi|>'].to(device).unsqueeze(0), + sptids['<|sot|>'].to(device).unsqueeze(0), + prompt_tensor + ], dim=1).long() + + output_ids = unwrapped_model.mmu_generate(input_ids, max_new_tokens=256, steps=256, block_length=128) + text = uni_prompting.text_tokenizer.batch_decode(output_ids[:, input_ids.shape[1]:], skip_special_tokens=True)[0] + print(text) + # 3. Log result + results_table.add_data(file_name, question, text) + + # except Exception as e: + # logger.error(f"Error processing video {file_name}: {e}") + + accelerator.log({"eval/v2t_qualitative_samples": results_table}, step=global_step) + + +################################################################################################ +# +++++++++++++++++++++++++++++++++++++ V2S EVALUATION LOGIC +++++++++++++++++++++++++++++++++++++ +################################################################################################ +@torch.no_grad() +def evaluate_v2s(model, vq_model_image, vq_model_audio, uni_prompting, config, accelerator, global_step): + if not accelerator.is_main_process: + return + + logger.info("***** Running V2S Qualitative Evaluation *****") + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.eval() + + try: + whisper_pipe = pipeline("automatic-speech-recognition", model="openai/whisper-large-v3", device=accelerator.device) + except Exception as exc: + logger.error(f"Failed to load Whisper model for V2S eval: {exc}") + return + + video_root = "/home/work/AIDAS/video/demo" + if not video_root or not os.path.exists(video_root): + logger.warning(f"V2S eval root '{video_root}' not found. Skipping V2S evaluation.") + return + + file_list = [f for f in os.listdir(video_root) if f.lower().endswith('.mp4')] + if not file_list: + logger.warning(f"No .mp4 files found in '{video_root}'. Skipping V2S evaluation.") + return + + question = "Please provide a detailed description of the video." + results_table = wandb.Table(columns=["Video ID", "Question", "Whisper Transcript", "Generated Audio"]) + + device = unwrapped_model.device + mask_token_id = unwrapped_model.config.mask_token_id + eoa_token_id = int(uni_prompting.sptids_dict['<|eoa|>'][0].item()) + audio_codebook_size = 4096 + max_audio_tokens = int(getattr(uni_prompting, "max_audio_len_short", config.dataset.preprocessing.max_aud_length_short)) + max_new_tokens = max(1, max_audio_tokens - 1) + block_length = 128 if max_new_tokens % 128 == 0 else max_new_tokens + + output_dir = os.path.join("/home/work/AIDAS", config.experiment.output_dir, "eval_audio_v2s", f"step_{global_step}") + os.makedirs(output_dir, exist_ok=True) + + for file_name in tqdm(file_list[:], desc="V2S Evaluation", disable=not accelerator.is_main_process): + video_path = os.path.join(video_root, file_name) + + cap = cv2.VideoCapture(video_path) + total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + indices = np.linspace(0, total_frames - 1, 8, dtype=int) + frames = [] + for idx in range(total_frames): + ret, frame = cap.read() + if idx in indices: + if not ret: + continue + frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + pil_img = Image.fromarray(frame) + frames.append(image_transform(pil_img, resolution=config.dataset.preprocessing.resolution)) + cap.release() + + if len(frames) < 8: + logger.warning(f"Skipping {file_name}: insufficient frames.") + continue + + video_tensor = torch.stack(frames).to(device) + try: + video_tokens = vq_model_image.get_code(video_tensor) + len(uni_prompting.text_tokenizer) + except Exception as exc: + logger.error(f"Failed to encode video {file_name}: {exc}") + continue + video_tokens = video_tokens.view(1, -1) + + prompt_text = f'<|start_header_id|>user<|end_header_id|>\n{question}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n' + prompt_tensor = uni_prompting.text_tokenizer(prompt_text, return_tensors="pt").input_ids.to(device) + + audio_placeholder = torch.full( + (1, max_new_tokens), + mask_token_id, + dtype=torch.long, + device=device, + ) + try: + seq_ids, attn_mask = uni_prompting( + (video_tokens, [prompt_text], [audio_placeholder]), + 'v2s_gen' + ) + except Exception as exc: + logger.error(f"Prompt construction failed for {file_name}: {exc}") + continue + + input_ids = seq_ids.to(device) + attention_mask = attn_mask.to(device) + + try: + generated_list = unwrapped_model.t2s_generate_mmu_like( + input_ids=input_ids, + max_new_tokens=256, + steps=256, + block_length=256, + temperature=1.0, + cfg_scale=3.0, + mask_token_id=mask_token_id, + attention_mask=attention_mask, + uni_prompting=uni_prompting, + codebook_size=config.model.omada.codebook_size, + audio_codebook_size=audio_codebook_size, + ) + except Exception as exc: + logger.error(f"Generation failed for {file_name}: {exc}") + continue + + generated_tokens = generated_list[0] + if isinstance(generated_tokens, torch.Tensor): + generated_tokens = generated_tokens.detach().cpu() + + collected = [] + for token in generated_tokens.tolist(): + if token == eoa_token_id or token >= audio_codebook_size: + break + if token >= 0: + collected.append(token) + + if not collected: + logger.warning(f"No valid audio tokens generated for {file_name}.") + continue + + speech_unit_for_decode = "".join(f"<|speech_{tok}|>" for tok in collected) + output_wav_path = os.path.join(output_dir, f"{Path(file_name).stem}_v2s.wav") + try: + vq_model_audio.decode( + speech_unit_for_decode, + condition='gender-female_emotion-neutral_speed-normal_pitch-normal', + output_wav_file=output_wav_path + ) + except Exception as exc: + logger.error(f"Decoding failed for {file_name}: {exc}") + continue + + whisper_result = whisper_pipe(output_wav_path, generate_kwargs={"language": "english"}) + whisper_text = whisper_result.get("text", "") + + results_table.add_data( + file_name, + question, + whisper_text, + wandb.Audio(output_wav_path, caption=whisper_text) + ) + + if len(results_table.data) == 0: + logger.warning("V2S evaluation produced no samples to log.") + return + + accelerator.log({"eval/v2s_qualitative_samples": results_table}, step=global_step) + + +################################################################################################ +# +++++++++++++++++++++++++++++++++++++ MAIN EVALUATION ORCHESTRATOR +++++++++++++++++++++++++++++ +################################################################################################ + +def run_evaluation(model, vq_model_image, vq_model_audio, uni_prompting, config, accelerator, global_step): + """ + Orchestrates the S2T, T2S, and V2T evaluations. + """ + if accelerator.is_main_process: + logger.info(f"--- Starting evaluation at step {global_step} ---") + model.eval() + + + if accelerator.is_main_process: + evaluate_v2s(model, vq_model_image, vq_model_audio, uni_prompting, config, accelerator, global_step) + evaluate_text(model, uni_prompting, config, accelerator, global_step) + evaluate_t2i(model, vq_model_image, uni_prompting, config, accelerator, global_step) + evaluate_i2i(model, vq_model_image, uni_prompting, config, accelerator, global_step) + evaluate_s2s(model, vq_model_audio, uni_prompting, config, accelerator, global_step) + evaluate_s2t(model, vq_model_audio, uni_prompting, config, accelerator, global_step) + evaluate_t2s_mmu_like(model, vq_model_audio, uni_prompting, config, accelerator, global_step) + # evaluate_v2t(model, vq_model_image, uni_prompting, config, accelerator, global_step) + + accelerator.wait_for_everyone() + if accelerator.is_main_process: + logger.info(f"--- Finished evaluation at step {global_step}. Returning to training. ---") + model.train() + + +def main(): + _configure_multiprocessing() + ######################### + # SETUP Accelerator # + ######################### + config = get_config() + + # Enable TF32 on Ampere GPUs + if config.training.enable_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.benchmark = True + torch.backends.cudnn.deterministic = False + + config.experiment.logging_dir = str(Path(config.experiment.output_dir) / "logs") + accelerator = Accelerator( + gradient_accumulation_steps=config.training.gradient_accumulation_steps, + mixed_precision=config.training.mixed_precision, + log_with="wandb", + project_dir=config.experiment.logging_dir, + split_batches=True, + ) + + total_batch_size_per_gpu = ( + config.training.batch_size_t2i + + config.training.batch_size_lm + + config.training.batch_size_mmu + + config.training.batch_size_v2t + + config.training.batch_size_s2t + + config.training.batch_size_t2s + + config.training.batch_size_s2s + ) - 1 # -1 since t2s/ s2t choice + + total_batch_size = ( + total_batch_size_per_gpu + * accelerator.num_processes + * config.training.gradient_accumulation_steps + ) + + if accelerator.distributed_type == DistributedType.DEEPSPEED: + accelerator.state.deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = ( + total_batch_size_per_gpu + ) + + ##################################### + # SETUP LOGGING, SEED and CONFIG # + ##################################### + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + set_verbosity_info() + else: + set_verbosity_error() + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + resume_wandb_run = config.wandb.resume + run_id = config.wandb.get("run_id", None) + if run_id is None: + resume_wandb_run = False + run_id = wandb.util.generate_id() + config.wandb.run_id = run_id + + wandb_init_kwargs = dict( + name=config.experiment.name, + id=run_id, + resume=resume_wandb_run, + entity=config.wandb.get("entity", None), + config_exclude_keys=[], + dir = config.experiment.logging_dir, + ) + wandb_config = {k: v for k, v in flatten_omega_conf(config, resolve=True)} + wandb_config.pop("experiment.resume_from_checkpoint") + + accelerator.init_trackers( + config.experiment.project, + config=wandb_config, + init_kwargs={"wandb": wandb_init_kwargs}, + ) + + if accelerator.is_main_process: + os.makedirs(config.experiment.output_dir, exist_ok=True) + config_path = Path(config.experiment.output_dir) / "config.yaml" + logging.info(f"Saving config to {config_path}") + OmegaConf.save(config, config_path) + + # If passed along, set the training seed now. + if config.training.seed is not None: + set_seed(config.training.seed) + + ######################### + # MODELS and OPTIMIZER # + ######################### + logger.info("Loading models and optimizer") + + tokenizer = AutoTokenizer.from_pretrained(config.model.omada.tokenizer_path, padding_side="left") + + uni_prompting = UniversalPrompting(tokenizer, max_text_len=config.dataset.preprocessing.max_seq_length, max_audio_len=config.dataset.preprocessing.max_aud_length, max_audio_len_short=config.dataset.preprocessing.max_aud_length_short, + special_tokens=( + "<|soi|>", "<|eoi|>", "<|sov|>", "<|eov|>", "<|t2i|>", + "<|mmu|>", "<|t2v|>", "<|v2v|>", "<|lvg|>", + # Omada Special Tokens + "<|i2i|>", "<|v2t|>", "<|v2s|>", "<|s2t|>", "<|t2s|>", "<|s2s|>", "<|soa|>", "<|eoa|>", + ), + ignore_id=-100, cond_dropout_prob=config.training.cond_dropout_prob, use_reserved_token=True) + + print('special tokens : \n', uni_prompting.sptids_dict) + + speech_vocab_start = len(uni_prompting.text_tokenizer) + int(config.model.omada.codebook_size) + audio_codebook_size = max(int(config.model.omada.new_vocab_size) - speech_vocab_start, 0) + + # speech_vocab_start = int(config.model.omada.llm_vocab_size) + int(config.model.omada.codebook_size) # 126464 + 8192 = 134656 + # audio_codebook_size = max(int(config.model.omada.new_vocab_size) - speech_vocab_start, 0) # 4096 + + logger.info(f"SPEECHVOCABSTART: {speech_vocab_start}") + logger.info(f"int(config.model.omada.new_vocab_size): {int(config.model.omada.new_vocab_size)}") + logger.info(f"AUDIOCODEBOOKSIZE: {audio_codebook_size}") + + t2s_special_token_ids = { + "eoa": int(uni_prompting.sptids_dict['<|eoa|>'][0].item()), + "eos": int(uni_prompting.text_tokenizer.eos_token_id), + } + + # VQ model for processing image into discrete tokens + vq_model_image = get_vq_model_class(config.model.vq_model_image.type) + if config.model.vq_model_image.get("pretrained_model_path", None): + vq_model_image = vq_model_image().to(accelerator.device) + state_dict = torch.load(config.model.vq_model_image.pretrained_model_path)['model'] + vq_model_image.load_state_dict(state_dict) + else: + vq_model_image = vq_model_image.from_pretrained(config.model.vq_model_image.vq_model_name).to(accelerator.device) + + vq_model_audio = get_vq_model_class(config.model.vq_model_audio.type) + vq_model_audio = vq_model_audio.from_pretrained(config.model.vq_model_audio.vq_model_name).to(accelerator.device) + + vq_model_image.eval() + vq_model_image.requires_grad_(False) + + vq_model_audio.eval() + vq_model_audio.requires_grad_(False) + + # Speech-token caching configuration + speech_cache_cfg = getattr(config.dataset, "speech_token_cache", {}) + if not isinstance(speech_cache_cfg, dict): + speech_cache_cfg = OmegaConf.to_container(speech_cache_cfg, resolve=True) + speech_cache_cfg = speech_cache_cfg or {} + + speech_cache_enabled = bool(speech_cache_cfg.get("enable", False)) + speech_cache_dir: Optional[Path] + if speech_cache_enabled: + cache_root = speech_cache_cfg.get("root", "cache/speech_tokens") + speech_cache_dir = Path(cache_root) + try: + speech_cache_dir.mkdir(parents=True, exist_ok=True) + except OSError: + speech_cache_dir = None + speech_cache_enabled = False + logger.warning("Failed to create speech cache directory at %s; disabling cache.", cache_root) + else: + speech_cache_dir = None + + speech_cache_max_items = int(speech_cache_cfg.get("max_items_in_memory", 4096)) + audio_token_cache_mem: Dict[str, torch.Tensor] = {} + + def _get_audio_cache_path(audio_path: Union[str, Path]) -> Optional[Path]: + if not isinstance(audio_path, (str, os.PathLike)): + return None + if not speech_cache_enabled or speech_cache_dir is None: + return None + key = os.path.abspath(str(audio_path)) + digest = hashlib.sha1(key.encode("utf-8")).hexdigest() + subdir = speech_cache_dir / digest[:2] / digest[2:4] + return subdir / f"{digest}.pt" + + def _load_cached_audio_tokens(audio_path: Union[str, Path]) -> Optional[torch.Tensor]: + if not isinstance(audio_path, (str, os.PathLike)): + return None + cache_key = os.path.abspath(str(audio_path)) + cached = audio_token_cache_mem.get(cache_key) + if cached is not None: + return cached.clone() + + cache_path = _get_audio_cache_path(audio_path) + if cache_path is None or not cache_path.exists(): + return None + try: + tokens = torch.load(cache_path, map_location="cpu") + if isinstance(tokens, torch.Tensor): + if len(audio_token_cache_mem) < speech_cache_max_items: + audio_token_cache_mem[cache_key] = tokens + return tokens.clone() + except Exception as exc: + logger.warning("Failed to load cached speech tokens from %s (%s); ignoring cache.", cache_path, exc) + return None + + def _store_cached_audio_tokens(audio_path: Union[str, Path], tokens: torch.Tensor) -> None: + if not isinstance(audio_path, (str, os.PathLike)): + return + cache_path = _get_audio_cache_path(audio_path) + if cache_path is None: + return + try: + cache_path.parent.mkdir(parents=True, exist_ok=True) + tmp_path = cache_path.with_suffix(cache_path.suffix + ".tmp") + torch.save(tokens.cpu(), tmp_path) + os.replace(tmp_path, cache_path) + except Exception as exc: + logger.warning("Failed to write speech token cache to %s (%s).", cache_path, exc) + return + cache_key = os.path.abspath(str(audio_path)) + if len(audio_token_cache_mem) < speech_cache_max_items: + audio_token_cache_mem[cache_key] = tokens.cpu() + + model = OMadaModelLM.from_pretrained(config.model.omada.pretrained_model_path, torch_dtype=torch.bfloat16, config='/home/work/AIDAS/ckpts/omada/omada-training-stage1_7th/checkpoint-315000/unwrapped_model/config.json').to(accelerator.device) + mask_id = model.config.mask_token_id + + ################################## + # Optimizer and LR scheduler # + ################################# + optimizer_config = config.optimizer.params + + # no decay on bias and layernorm and embedding + no_decay = ["bias", "layer_norm.weight", "mlm_ln.weight", "embeddings.weight"] + optimizer_grouped_parameters = [ + { + "params": [p for n, p in model.named_parameters() if + p.requires_grad and not any(nd in n for nd in no_decay)], + "weight_decay": optimizer_config.weight_decay, + }, + { + "params": [p for n, p in model.named_parameters() if + p.requires_grad and any(nd in n for nd in no_decay)], + "weight_decay": 0.0, + }, + ] + + optimizer_type = config.optimizer.name + if optimizer_type == "adamw": + optimizer = AdamW( + optimizer_grouped_parameters, + lr=optimizer_config.learning_rate, + betas=(optimizer_config.beta1, optimizer_config.beta2), + weight_decay=optimizer_config.weight_decay, + eps=optimizer_config.epsilon, + ) + else: + raise ValueError(f"Optimizer {optimizer_type} not supported") + + # Create mask scheduler + if config.get("mask_schedule", None) is not None: + schedule = config.mask_schedule.schedule + args = config.mask_schedule.get("params", {}) + mask_schedule = get_mask_schedule(schedule, **args) + else: + mask_schedule = get_mask_schedule(config.training.get("mask_schedule", "cosine")) + + ################################## + # DATALOADER # + ################################# + logger.info("Creating dataloaders and lr_scheduler") + + def build_distributed_sampler(dataset, *, shuffle=True, drop_last=True): + """Create a DistributedSampler only when running with multiple processes.""" + if dataset is None or accelerator.num_processes <= 1: + return None + return DistributedSampler( + dataset, + num_replicas=accelerator.num_processes, + rank=accelerator.process_index, + shuffle=shuffle, + drop_last=drop_last, + ) + + batch_size_t2i_cfg = config.training.batch_size_t2i + batch_size_lm_cfg = config.training.batch_size_lm + batch_size_mmu_cfg = config.training.batch_size_mmu + batch_size_t2s_cfg = config.training.batch_size_t2s + batch_size_s2t_cfg = config.training.batch_size_s2t + batch_size_v2t_cfg = config.training.batch_size_v2t + batch_size_s2s_cfg = config.training.batch_size_s2s + batch_size_v2s_cfg = batch_size_v2t_cfg + + total_batch_size = ( + total_batch_size_per_gpu + * accelerator.num_processes + * config.training.gradient_accumulation_steps + ) + preproc_config = config.dataset.preprocessing + dataset_config = config.dataset.params + + pin_memory = bool(getattr(dataset_config, "pin_memory", False)) + persistent_workers = bool(getattr(dataset_config, "persistent_workers", False)) + dataloader_timeout = int(getattr(dataset_config, "dataloader_timeout", 120)) + + if persistent_workers and dataloader_timeout > 0: + logger.warning( + "persistent_workers=True requires dataloader_timeout=0; overriding timeout=%s", + dataloader_timeout, + ) + dataloader_timeout = 0 + + if ( + not persistent_workers + and int(getattr(dataset_config, "num_workers", 0)) > 0 + and str(config.dataset.combined_loader_mode) == "max_size_cycle" + ): + logger.warning( + "Using combined_loader_mode='max_size_cycle' with num_workers>0 and " + "persistent_workers=False can exhaust OS semaphores when loaders cycle. " + "Set dataset.params.persistent_workers=True to keep worker processes alive." + ) + + # Text-to-image / Image-to-image datasets + logger.info("Loading Text-to-image / Image-to-image datasets") + dataset_t2i = None + dataset_i2i = None + train_dataloader_t2i = None + train_dataloader_i2i = None + sampler_t2i: Optional[DistributedSampler] = None # type: ignore[assignment] + sampler_i2i: Optional[DistributedSampler] = None # type: ignore[assignment] + if batch_size_t2i_cfg > 0: + raw_t2i_choice = dataset_config.get("t2i_dataset", "hqedit") + if isinstance(raw_t2i_choice, str): + split_tokens = [token.strip() for token in raw_t2i_choice.replace(",", "+").split("+")] + dataset_choices = [token for token in split_tokens if token] + else: + dataset_choices = [str(token).strip() for token in raw_t2i_choice if str(token).strip()] + + if not dataset_choices: + raise ValueError("t2i_dataset configuration produced no valid dataset names.") + + t2i_datasets: list[Dataset] = [] + i2i_datasets: list[Dataset] = [] + t2i_source_names: list[str] = [] + i2i_source_names: list[str] = [] + for choice in dataset_choices: + choice_lower = choice.lower() + if choice_lower in {"hqedit", "hq-edit", "hq_edit"}: + i2i_datasets.append( + HQEditX2IDataset( + split=dataset_config.get("hqedit_split", "train"), + resolution=dataset_config.resolution, + ) + ) + logger.info("Using HQ-Edit dataset for T2I/i2i branch (%s split)", dataset_config.get("hqedit_split", "train")) + i2i_source_names.append(choice) + elif choice_lower in {"text2image2m", "text-to-image-2m", "text_to_image_2m"}: + t2i_datasets.append( + TextToImage2MDataset( + split=dataset_config.get("t2i_split", "train"), + resolution=dataset_config.resolution, + dataset_name=dataset_config.get("t2i_dataset_name", "jackyhate/text-to-image-2M"), + cache_dir=dataset_config.get("t2i_cache_dir", None), + local_files_only=bool(dataset_config.get("t2i_local_files_only", False)), + ) + ) + logger.info( + "Using text-to-image-2M dataset for T2I branch (split=%s, dataset=%s)", + dataset_config.get("t2i_split", "train"), + dataset_config.get("t2i_dataset_name", "jackyhate/text-to-image-2M"), + ) + t2i_source_names.append(choice) + elif choice_lower in {"openimage", "openimage_i2i", "openimage-edit", "openimage_local"}: + raw_openimage_cfg = getattr(dataset_config, "openimage_i2i", None) + if raw_openimage_cfg is None: + raise ValueError("dataset.params.openimage_i2i must be configured to use the OpenImage dataset.") + if not isinstance(raw_openimage_cfg, dict): + openimage_cfg = OmegaConf.to_container(raw_openimage_cfg, resolve=True) or {} + else: + openimage_cfg = raw_openimage_cfg + + i2i_datasets.append( + OpenImageI2IDataset( + resolution=dataset_config.resolution, + image_root=openimage_cfg.get("image_root"), + sft_jsonl=openimage_cfg.get("sft_jsonl"), + pref_jsonl=openimage_cfg.get("pref_jsonl"), + multi_turn_jsonl=openimage_cfg.get("multi_turn_jsonl"), + prefer_summarized_text=bool(openimage_cfg.get("prefer_summarized_text", True)), + pref_positive_only=bool(openimage_cfg.get("pref_positive_only", True)), + skip_missing=bool(openimage_cfg.get("skip_missing", True)), + max_samples_per_source=openimage_cfg.get("max_samples_per_source"), + max_total_samples=openimage_cfg.get("max_total_samples"), + seed=openimage_cfg.get("seed"), + ) + ) + cfg_paths = [ + openimage_cfg.get("sft_jsonl"), + openimage_cfg.get("pref_jsonl"), + openimage_cfg.get("multi_turn_jsonl"), + ] + cfg_paths = [str(path) for path in cfg_paths if path] + logger.info( + "Using OpenImage local edit dataset for i2i (jsonl=%s)", + ", ".join(cfg_paths) if cfg_paths else "n/a", + ) + i2i_source_names.append(choice) + else: + raise ValueError(f"Unsupported t2i_dataset '{choice}'") + + if t2i_datasets: + dataset_t2i = ( + t2i_datasets[0] + if len(t2i_datasets) == 1 + else CombinedX2IDataset(t2i_datasets) + ) + logger.info( + "T2I dataloading sources: %s", + ", ".join(t2i_source_names) if t2i_source_names else "n/a", + ) + + sampler_t2i = build_distributed_sampler( + dataset_t2i, + shuffle=True, + drop_last=True, + ) + + train_dataloader_t2i = DataLoader( + dataset_t2i, + batch_size=batch_size_t2i_cfg, + sampler=sampler_t2i, + shuffle=sampler_t2i is None, + num_workers=dataset_config.num_workers, + collate_fn=collate_fn_x2i, + drop_last=True, + pin_memory=pin_memory, + timeout=dataloader_timeout, + persistent_workers=persistent_workers, + ) + + if i2i_datasets: + dataset_i2i = ( + i2i_datasets[0] + if len(i2i_datasets) == 1 + else CombinedX2IDataset(i2i_datasets) + ) + logger.info( + "I2I dataloading sources: %s", + ", ".join(i2i_source_names) if i2i_source_names else "n/a", + ) + + sampler_i2i = build_distributed_sampler( + dataset_i2i, + shuffle=True, + drop_last=True, + ) + + train_dataloader_i2i = DataLoader( + dataset_i2i, + batch_size=batch_size_t2i_cfg, + sampler=sampler_i2i, + shuffle=sampler_i2i is None, + num_workers=dataset_config.num_workers, + collate_fn=collate_fn_x2i, + drop_last=True, + pin_memory=pin_memory, + timeout=dataloader_timeout, + persistent_workers=persistent_workers, + ) + + # Language modeling dataset (HF instruction mixture) + logger.info("Loading LM dataset") + dataset_lm = None + train_dataloader_lm = None + if batch_size_lm_cfg > 0: + instruction_cfg = getattr(dataset_config, "hf_instruction_lm", {}) + if not isinstance(instruction_cfg, dict): + instruction_cfg = OmegaConf.to_container(instruction_cfg, resolve=True) + instruction_cfg = instruction_cfg or {} + + seed_lm = instruction_cfg.get("seed") + if seed_lm is None: + seed_lm = getattr(config.training, "seed", 42) or 42 + + dataset_lm = HFInstructionTextDataset( + split=instruction_cfg.get("split", "train"), + max_samples_per_source=instruction_cfg.get("max_samples_per_source"), + max_total_samples=instruction_cfg.get("max_total_samples"), + seed=int(seed_lm), + ) + + sampler_lm = build_distributed_sampler( + dataset_lm, + shuffle=True, + drop_last=True, + ) + + train_dataloader_lm = DataLoader( + dataset_lm, + batch_size=batch_size_lm_cfg, + sampler=sampler_lm, + shuffle=sampler_lm is None, + collate_fn=dataset_lm.collate_fn, + num_workers=dataset_config.num_workers, + drop_last=True, + pin_memory=pin_memory, + timeout=dataloader_timeout, + persistent_workers=persistent_workers, + ) + + # Video Dataset + logger.info("Loading Video dataset") + dataset_v2t = None + dataset_v2s = None + train_dataloader_v2t = None + train_dataloader_v2s = None + sampler_v2t = None + sampler_v2s = None + speech_cfg = getattr(dataset_config, "video_speech_dataset", {}) + if not isinstance(speech_cfg, dict): + speech_cfg = OmegaConf.to_container(speech_cfg, resolve=True) + speech_cfg = speech_cfg or {} + + if batch_size_v2t_cfg > 0: + dataset_v2t = VideoCaptionDataset( + transform=image_transform, + tokenizer=uni_prompting.text_tokenizer, + max_seq_length=preproc_config.max_seq_length, + resolution=preproc_config.resolution, + sample_method="uniform", + dataset_name=speech_cfg.get("llavavid_dataset_name", "llavavid"), + llavavid_path=speech_cfg.get("llavavid_path", "lmms-lab/LLaVA-Video-178K"), + num_frames=8, + llavavid_local_files_only=bool(speech_cfg.get("llavavid_local_files_only", False)), + llavavid_skip_configs=speech_cfg.get("llavavid_skip_configs"), + llavavid_skip_video_patterns=speech_cfg.get("llavavid_skip_video_patterns"), + ) + + sampler_v2t = build_distributed_sampler( + dataset_v2t, + shuffle=True, + drop_last=True, + ) + + train_dataloader_v2t = DataLoader( + dataset_v2t, + batch_size=batch_size_v2t_cfg, + num_workers=dataset_config.num_workers, + collate_fn=collate_fn_v2t, + sampler=sampler_v2t, + shuffle=sampler_v2t is None, + drop_last=True, + pin_memory=pin_memory, + timeout=dataloader_timeout, + persistent_workers=persistent_workers, + ) + + if batch_size_v2s_cfg > 0: + dataset_v2s = VideoSpeechDataset( + transform=image_transform, + resolution=preproc_config.resolution, + num_frames=speech_cfg.get("num_frames_speech", 4), + video_root=speech_cfg.get( + "video_root", "/home/work/AIDAS/data/video/openvid1m/video/video" + ), + audio_root=speech_cfg.get( + "audio_root", "/home/work/AIDAS/data/video-speech" + ), + speech_dir_name=speech_cfg.get("speech_dir_name", "openvid-speech-trunc"), + index_path=speech_cfg.get( + "index_path", "/home/work/AIDAS/data/video-speech/openvid-speech.csv" + ), + sample_method=speech_cfg.get("sample_method", "uniform"), + precomputed_tokens_root=( + speech_cfg.get("precomputed_tokens_root") + if speech_cfg.get("use_precomputed_tokens", False) + else None + ), + ) + + sampler_v2s = build_distributed_sampler( + dataset_v2s, + shuffle=True, + drop_last=True, + ) + + train_dataloader_v2s = DataLoader( + dataset_v2s, + batch_size=batch_size_v2s_cfg, + num_workers=dataset_config.num_workers, + collate_fn=collate_fn_v2s, + sampler=sampler_v2s, + shuffle=sampler_v2s is None, + drop_last=True, + pin_memory=pin_memory, + timeout=dataloader_timeout, + persistent_workers=persistent_workers, + ) + + # Speech Dataset + logger.info("Loading Speech dataset") + dataset_sm = MixedSpeechTextDataset(config.dataset.params.audio_data) + + # Speech-to-Speech Dataset (EMOVA + Instruct S2S) + dataset_s2s = None + sampler_s2s = None + train_dataloader_s2s = None + if config.training.batch_size_s2s > 0: + dataset_s2s = Speech2SpeechDataset(dataset_config.get("speech2speech", [])) + + # Multi-image interleaved dataset (MMU-style) + logger.info("Loading MMU dataset") + dataset_mmu = None + sampler_mmu = None + train_dataloader_mmu = None + if config.training.batch_size_mmu > 0: + mmu_params = dataset_config.get("mmu_interleaved", {}) + if mmu_params is None: + mmu_kwargs = {} + elif isinstance(mmu_params, dict): + mmu_kwargs = mmu_params + else: + mmu_kwargs = OmegaConf.to_container(mmu_params, resolve=True) + dataset_mmu = TextImageInterleavedDataset(**mmu_kwargs) + + logger.info("Dataset Prepared.") + + # Use distinct DistributedSamplers for each speech dataloader to avoid iterator interference + if accelerator.num_processes > 1: + sampler_s2t = DistributedSampler( + dataset_sm, + num_replicas=accelerator.num_processes, + rank=accelerator.process_index, + shuffle=True, + drop_last=True, + ) + sampler_t2s = DistributedSampler( + dataset_sm, + num_replicas=accelerator.num_processes, + rank=accelerator.process_index, + shuffle=True, + drop_last=True, + ) + if dataset_s2s is not None: + sampler_s2s = DistributedSampler( + dataset_s2s, + num_replicas=accelerator.num_processes, + rank=accelerator.process_index, + shuffle=True, + drop_last=True, + ) + if dataset_mmu is not None: + sampler_mmu = DistributedSampler( + dataset_mmu, + num_replicas=accelerator.num_processes, + rank=accelerator.process_index, + shuffle=True, + drop_last=True, + ) + else: + sampler_s2t = None + sampler_t2s = None + sampler_s2s = None + sampler_mmu = None + + train_dataloader_s2t = DataLoader( + dataset_sm, + batch_size=config.training.batch_size_s2t, + shuffle=False, + sampler=sampler_s2t, + collate_fn=collate_fn_audio, + num_workers=config.dataset.params.num_workers, + drop_last=True, + pin_memory=pin_memory, + timeout=dataloader_timeout, + persistent_workers=persistent_workers, + ) + train_dataloader_t2s = DataLoader( + dataset_sm, + batch_size=config.training.batch_size_t2s, + shuffle=False, + sampler=sampler_t2s, + collate_fn=collate_fn_audio, + num_workers=config.dataset.params.num_workers, + drop_last=True, + pin_memory=pin_memory, + timeout=dataloader_timeout, + persistent_workers=persistent_workers, + ) + + if dataset_s2s is not None: + train_dataloader_s2s = DataLoader( + dataset_s2s, + batch_size=batch_size_s2s_cfg, + shuffle=False, + sampler=sampler_s2s, + collate_fn=s2s_collate_fn, + num_workers=config.dataset.params.num_workers, + drop_last=True, + pin_memory=pin_memory, + timeout=dataloader_timeout, + persistent_workers=persistent_workers, + ) + + if dataset_mmu is not None: + train_dataloader_mmu = DataLoader( + dataset_mmu, + batch_size=config.training.batch_size_mmu, + shuffle=False, + sampler=sampler_mmu, + collate_fn=collate_fn_mmu_mult, + num_workers=config.dataset.params.num_workers, + drop_last=True, + pin_memory=pin_memory, + timeout=dataloader_timeout, + persistent_workers=persistent_workers, + ) + + # Combine these dataloaders into a single iterable model + iterables = {} + if train_dataloader_lm is not None: + iterables["lm_flow"] = train_dataloader_lm + if train_dataloader_mmu is not None: + iterables["mmu_flow"] = train_dataloader_mmu + if train_dataloader_s2s is not None: + iterables["s2s_flow"] = train_dataloader_s2s + + if not iterables: + raise ValueError( + "CombinedLoader requires at least one non-speech iterable when speech flows are randomized. " + "Enable another dataset (e.g., t2i, lm, mmu) or disable speech randomization." + ) + + combined_dataloader = CombinedLoader(iterables, mode=config.dataset.combined_loader_mode) + + def _num_steps(dataset_obj, batch_size_cfg): + if dataset_obj is None or batch_size_cfg <= 0: + return 0 + total_bs = batch_size_cfg * accelerator.num_processes * config.training.gradient_accumulation_steps + if total_bs <= 0: + return 0 + length = len(dataset_obj) + if length == 0: + return 0 + return math.ceil(length / total_bs) + + num_update_steps_per_epoch_t2i = _num_steps(dataset_t2i, config.training.batch_size_t2i) + num_update_steps_per_epoch_i2i = _num_steps(dataset_i2i, config.training.batch_size_t2i) + num_update_steps_per_epoch_lm = _num_steps(dataset_lm, config.training.batch_size_lm) + num_update_steps_per_epoch_s2t = _num_steps(dataset_sm, config.training.batch_size_s2t) + num_update_steps_per_epoch_t2s = _num_steps(dataset_sm, config.training.batch_size_t2s) + num_update_steps_per_epoch_s2s = _num_steps(dataset_s2s, batch_size_s2s_cfg) + num_update_steps_per_epoch_v2t = _num_steps(dataset_v2t, batch_size_v2t_cfg) + num_update_steps_per_epoch_v2s = _num_steps(dataset_v2s, batch_size_v2s_cfg) + num_update_steps_per_epoch_mmu = _num_steps(dataset_mmu, config.training.batch_size_mmu) + + # Calculate num_train_epochs + num_update_steps_per_epoch = max( + num_update_steps_per_epoch_t2i, + num_update_steps_per_epoch_lm, + num_update_steps_per_epoch_s2t, + num_update_steps_per_epoch_t2s, + num_update_steps_per_epoch_v2t, + num_update_steps_per_epoch_v2s, + num_update_steps_per_epoch_s2s, + num_update_steps_per_epoch_mmu, + num_update_steps_per_epoch_i2i, + ) + + num_train_epochs = math.ceil(config.training.max_train_steps / num_update_steps_per_epoch) if num_update_steps_per_epoch > 0 else 1 + + logger.info(f"len of T2I: {len(dataset_t2i) if dataset_t2i is not None else 0}") + logger.info(f"len of I2I: {len(dataset_i2i) if dataset_i2i is not None else 0}") + logger.info(f"len of LM: {len(dataset_lm)}") + logger.info(f"len of Speech: {len(dataset_sm)}") + logger.info(f"len of Video Caption: {len(dataset_v2t) if dataset_v2t is not None else 0}") + logger.info(f"len of Video Speech: {len(dataset_v2s) if dataset_v2s is not None else 0}") + logger.info(f"len of S2S: {len(dataset_s2s)}") + logger.info(f"len of MMU: {len(dataset_mmu)}") + + logger.info(f"Train stpes: {config.training.max_train_steps}") + logger.info(f"Num train epochs: {num_train_epochs}") + + ################################## + # MODEL RESUME # + ################################# + global_step = 0 + first_epoch = 0 + start_step = 0 + + if config.experiment.resume_from_checkpoint: + dirs = os.listdir(config.experiment.output_dir) + logger.info(f"dirs: {dirs}") + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + logger.info(f"path: {path}") + if path is not None: + path = os.path.join(config.experiment.output_dir, path) + logger.info(f"Resuming from checkpoint: {path}") + global_step = start_step = int(os.path.basename(path).split("-")[1]) + first_epoch = global_step // num_update_steps_per_epoch + if os.path.exists(f'{path}/unwrapped_model/pytorch_model.bin'): + state_dict = torch.load(f'{path}/unwrapped_model/pytorch_model.bin', map_location="cpu") + model.load_state_dict(state_dict, strict=True) + del state_dict + elif os.path.exists(f'{path}/unwrapped_model/pytorch_model.bin.index.json'): + from safetensors.torch import load_file + from transformers.modeling_utils import load_sharded_checkpoint + load_sharded_checkpoint(model, f'{path}/unwrapped_model/') + # if safetensors sharded checkpoint exists + elif os.path.exists(f'{path}/unwrapped_model/model.safetensors.index.json'): + from transformers.modeling_utils import load_sharded_checkpoint + load_sharded_checkpoint( + model, + f'{path}/unwrapped_model/', + ) + else: + raise FileNotFoundError(f"Checkpoint {path}/unwrapped_model/pytorch_model.bin or safetensors not found") + else: + logger.info("Not resuming from checkpoint") + + ################################## + # Prepare accelerator # + ################################# + logger.info("Preparing model, optimizer and dataloaders") + + lr_scheduler = get_scheduler( + config.lr_scheduler.scheduler, + optimizer=optimizer, + num_training_steps=config.training.max_train_steps, + num_warmup_steps=config.lr_scheduler.params.warmup_steps, + min_lr_scale=config.lr_scheduler.params.min_lr_scale + ) + + # model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler) + model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler) + + lr_scheduler = get_scheduler( + config.lr_scheduler.scheduler, + optimizer=optimizer, + num_training_steps=config.training.max_train_steps, + num_warmup_steps=config.lr_scheduler.params.warmup_steps, + min_lr_scale=config.lr_scheduler.params.min_lr_scale + ) + + vq_model_image.to(device=accelerator.device) + vq_model_audio.to(device=accelerator.device) + + mask_dtype = model.get_input_embeddings().weight.dtype + + def _log_and_flag_failure(message: str, exc: Exception = None): + """Log preprocessing failures on both logger and accelerator console.""" + if exc is not None: + logger.exception(message) + else: + logger.error(message) + accelerator.print(message) + + def _maybe_trim_audio_file(audio_path: Union[str, os.PathLike], max_duration: float) -> tuple[Union[str, os.PathLike], Optional[str]]: + """Return a path to an audio file trimmed to max_duration seconds. + + If trimming succeeds, returns (trimmed_path, temp_path) where trimmed_path is the + file to use for encoding and temp_path should be deleted afterwards. If trimming + fails, returns (audio_path, None). + """ + if max_duration <= 0: + return audio_path, None + trim_timeout = float(getattr(config.dataset.preprocessing, "audio_trim_timeout_sec", 30.0)) + try: + tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False) + tmp_path = tmp.name + tmp.close() + cmd = [ + "ffmpeg", + "-y", + "-hide_banner", + "-loglevel", + "error", + "-i", + str(audio_path), + "-t", + str(max_duration), + "-c", + "copy", + tmp_path, + ] + subprocess.run(cmd, check=True, timeout=trim_timeout) + return tmp_path, tmp_path + except Exception as exc: + warnings.warn(f"Failed to trim audio {audio_path} to {max_duration}s: {exc}") + try: + if 'tmp_path' in locals() and os.path.exists(tmp_path): + os.remove(tmp_path) + except OSError: + pass + return audio_path, None + + def _format_path_for_log(path: Union[str, os.PathLike, torch.Tensor, None]) -> str: + if isinstance(path, (str, os.PathLike)): + try: + return os.fspath(path) + except TypeError: + return str(path) + if isinstance(path, torch.Tensor): + return f"" + if isinstance(path, np.ndarray): + return f"" + if isinstance(path, Sequence) and not isinstance(path, (str, bytes, os.PathLike)): + try: + return f"" + except Exception: + return "" + return repr(path) + + def safe_audio_encode(audio_path: Union[str, torch.Tensor, np.ndarray, Sequence[int]], flow_name: str): + if isinstance(audio_path, torch.Tensor): + return audio_path.cpu().clone(), None + if isinstance(audio_path, np.ndarray): + try: + tensor = torch.from_numpy(audio_path).to(dtype=torch.long) + except Exception as exc: + raise RuntimeError(f"Failed to convert numpy audio tokens to tensor for flow '{flow_name}': {exc}") from exc + return tensor, None + if isinstance(audio_path, Sequence) and not isinstance(audio_path, (str, bytes, os.PathLike)): + try: + tensor = torch.as_tensor(audio_path, dtype=torch.long) + except Exception as exc: + raise RuntimeError(f"Failed to convert cached audio tokens to tensor for flow '{flow_name}': {exc}") from exc + return tensor, None + path_repr = _format_path_for_log(audio_path) + if logger.isEnabledFor(logging.DEBUG): + logger.debug( + "[rank %s] (%s) audio encode request: %s", + accelerator.process_index, + flow_name, + path_repr, + ) + + max_retries = int(getattr(config.dataset.preprocessing, "audio_encode_max_retries", 3)) + backoff = float(getattr(config.dataset.preprocessing, "audio_encode_retry_backoff_sec", 0.5)) + duration_limit = float(getattr(config.dataset.preprocessing, "max_audio_duration_sec", 15.0)) + + cached = _load_cached_audio_tokens(audio_path) + if cached is not None: + if logger.isEnabledFor(logging.DEBUG): + logger.debug( + "[rank %s] (%s) audio encode hit cache: %s", + accelerator.process_index, + flow_name, + path_repr, + ) + return cached, None + + for attempt in range(1, max_retries + 1): + trimmed_path: Union[str, os.PathLike] = audio_path + temp_path: Optional[str] = None + try: + if isinstance(audio_path, (str, os.PathLike)): + trimmed_path, temp_path = _maybe_trim_audio_file(audio_path, duration_limit) + if logger.isEnabledFor(logging.DEBUG): + logger.debug( + "[rank %s] (%s) audio encode attempt %d/%d (trimmed=%s): %s", + accelerator.process_index, + flow_name, + attempt, + max_retries, + "yes" if temp_path is not None else "no", + _format_path_for_log(trimmed_path), + ) + tokens = vq_model_audio.encode(str(trimmed_path)).cpu() + _store_cached_audio_tokens(audio_path, tokens) + if logger.isEnabledFor(logging.DEBUG): + logger.debug( + "[rank %s] (%s) audio encode success: %s", + accelerator.process_index, + flow_name, + path_repr, + ) + return tokens, None + except Exception as exc: + if attempt == max_retries: + msg = ( + f"[Rank {accelerator.process_index}] {flow_name} audio encode failed " + f"for '{audio_path}': {exc}" + ) + _log_and_flag_failure(msg, exc) + return None, msg + sleep_time = min(backoff * attempt, 2.0) + time.sleep(sleep_time) + finally: + if temp_path is not None and os.path.exists(temp_path): + try: + os.remove(temp_path) + except OSError: + pass + + def safe_video_get_code(video_tensor_sample: torch.Tensor, sample_index: int): + max_retries = int(getattr(config.dataset.preprocessing, "video_encode_max_retries", 3)) + backoff = float(getattr(config.dataset.preprocessing, "video_encode_retry_backoff_sec", 0.5)) + for attempt in range(1, max_retries + 1): + try: + if logger.isEnabledFor(logging.DEBUG): + logger.debug( + "[rank %s] video encode request sample=%d attempt=%d/%d", + accelerator.process_index, + sample_index, + attempt, + max_retries, + ) + video_token = vq_model_image.get_code(video_tensor_sample) + if logger.isEnabledFor(logging.DEBUG): + logger.debug( + "[rank %s] video encode success sample=%d", + accelerator.process_index, + sample_index, + ) + return video_token, None + except Exception as exc: + if attempt == max_retries: + msg = ( + f"[Rank {accelerator.process_index}] v2t video encode failed " + f"for sample index {sample_index}: {exc}" + ) + _log_and_flag_failure(msg, exc) + return None, msg + logger.warning( + "[rank %s] video encode retry sample=%d attempt=%d/%d error=%s", + accelerator.process_index, + sample_index, + attempt, + max_retries, + exc, + ) + sleep_time = min(backoff * attempt, 2.0) + time.sleep(sleep_time) + + def safe_image_get_code(image_tensor_sample: torch.Tensor, sample_index: int): + max_retries = int(getattr(config.dataset.preprocessing, "image_encode_max_retries", 3)) + backoff = float(getattr(config.dataset.preprocessing, "image_encode_retry_backoff_sec", 0.5)) + for attempt in range(1, max_retries + 1): + try: + if logger.isEnabledFor(logging.DEBUG): + logger.debug( + "[rank %s] image encode request sample=%d attempt=%d/%d", + accelerator.process_index, + sample_index, + attempt, + max_retries, + ) + if image_tensor_sample.dim() == 3: + image_tensor_sample = image_tensor_sample.unsqueeze(0) + elif image_tensor_sample.dim() != 4: + raise ValueError( + f"Expected image tensor with 3 or 4 dims, got shape {tuple(image_tensor_sample.shape)}" + ) + image_token = vq_model_image.get_code(image_tensor_sample) + if logger.isEnabledFor(logging.DEBUG): + logger.debug( + "[rank %s] image encode success sample=%d", + accelerator.process_index, + sample_index, + ) + return image_token, None + except Exception as exc: + if attempt == max_retries: + msg = ( + f"[Rank {accelerator.process_index}] s2s image encode failed " + f"for sample index {sample_index}: {exc}" + ) + _log_and_flag_failure(msg, exc) + return None, msg + logger.warning( + "[rank %s] image encode retry sample=%d attempt=%d/%d error=%s", + accelerator.process_index, + sample_index, + attempt, + max_retries, + exc, + ) + sleep_time = min(backoff * attempt, 2.0) + time.sleep(sleep_time) + + def _decode_single_image(single_like): + if single_like is None: + return None + if isinstance(single_like, Image.Image): + return single_like.convert('RGB') + + data_bytes = None + + if isinstance(single_like, (bytes, bytearray)): + data_bytes = bytes(single_like) + elif isinstance(single_like, str): + try: + data_bytes = base64.b64decode(single_like) + except (binascii.Error, ValueError): + if os.path.isfile(single_like): + try: + with open(single_like, 'rb') as fh: + data_bytes = fh.read() + except OSError: + data_bytes = None + elif isinstance(single_like, dict): + binary_payload = single_like.get('bytes') + if binary_payload is not None: + data_bytes = binary_payload + else: + path_value = single_like.get('path') + if path_value and os.path.isfile(path_value): + try: + with open(path_value, 'rb') as fh: + data_bytes = fh.read() + except OSError: + data_bytes = None + + if data_bytes is None: + return None + + try: + with Image.open(BytesIO(data_bytes)) as img: + return img.convert('RGB') + except Exception: + return None + + def maybe_decode_image(image_like): + if isinstance(image_like, (list, tuple)): + return [_decode_single_image(item) for item in image_like] + return _decode_single_image(image_like) + + ################################## + # Training # + ################################# + logger.info("***** Running training *****") + logger.info(f" Num training steps = {config.training.max_train_steps}") + logger.info(f" Instantaneous batch size per device = {total_batch_size_per_gpu}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {config.training.gradient_accumulation_steps}") + + @torch.no_grad() + def prepare_inputs_and_labels( + pixel_values_or_image_ids: Union[torch.FloatTensor, torch.LongTensor], + texts: Union[str, str], + min_masking_rate: float = 0.0, + is_train: bool = True, + seed: int = None + ): + + image_tokens = vq_model_image.get_code(pixel_values_or_image_ids) + image_tokens = image_tokens + len(uni_prompting.text_tokenizer) + # create MLM mask and labels + input_ids, labels, loss_weight, mask_prob = mask_or_random_replace_tokens( + image_tokens, + mask_id, + config, + mask_schedule=mask_schedule, + is_train=is_train, + ) + input_ids, masks, labels = uni_prompting((texts, input_ids, labels), 't2i') + return input_ids, labels, mask_prob, image_tokens, masks + + @torch.no_grad() + def prepare_inputs_and_labels_for_i2i( + source_images: torch.FloatTensor, + target_images: torch.FloatTensor, + prompts: list[str], + is_train: bool = True, + ): + """Build masked i2i sequences from source/target image pairs.""" + + # Tokenize source/target images with VQ model and offset by text vocab size + source_tokens = vq_model_image.get_code(source_images) + len(uni_prompting.text_tokenizer) + target_tokens = vq_model_image.get_code(target_images) + len(uni_prompting.text_tokenizer) + + cond_dropout_prob = config.training.get( + "i2i_cond_dropout_prob", + config.training.cond_dropout_prob, + ) + + if is_train and torch.rand(1, device=source_tokens.device).item() < cond_dropout_prob: + effective_prompts = [''] * len(prompts) + masked_target_source = source_tokens + else: + effective_prompts = list(prompts) + masked_target_source = target_tokens + + masked_target_tokens, labels, _, mask_prob = mask_or_random_replace_tokens( + masked_target_source, + mask_id, + config, + mask_schedule=mask_schedule, + is_train=is_train, + ) + + input_ids, attention_masks, labels = uni_prompting( + (effective_prompts, source_tokens, masked_target_tokens, labels), + 'i2i' + ) + + return input_ids, labels, mask_prob, attention_masks + + @torch.no_grad() + def prepare_inputs_and_labels_for_text( + texts: Union[str, str], max_seq_len, eps=1e-3 + ): + # create MLM mask and labels + + input_ids_lm, prompt_mask, labels_lm = uni_prompting((texts, max_seq_len), 'lm') + b, l = input_ids_lm.shape + t = torch.rand(b, device=input_ids_lm.device) + p_mask = (1 - eps) * t + eps + p_mask = p_mask[:, None].repeat(1, l) + + masked_indices = torch.rand((b, l), device=input_ids_lm.device) < p_mask + # 126336 is used for [MASK] token + noisy_batch = torch.where(masked_indices, mask_id, input_ids_lm) + masked_indices = noisy_batch == mask_id + + return noisy_batch, labels_lm, p_mask + + # Video also uses this. + @torch.no_grad() + def prepare_inputs_and_labels_for_mmu( + input_ids_mmu, prompt_masks, labels_mmu, eps=1e-3 + ): + b, l = input_ids_mmu.shape + t = torch.rand(b, device=input_ids_mmu.device) + p_mask = (1 - eps) * t + eps + p_mask = p_mask[:, None].repeat(1, l) + + masked_indices = torch.rand((b, l), device=input_ids_mmu.device) < p_mask + # 126336 is used for [MASK] token + noisy_batch = torch.where(masked_indices, mask_id, input_ids_mmu) + masked_indices = noisy_batch == mask_id + noisy_batch[prompt_masks.bool()] = input_ids_mmu[prompt_masks.bool()] + masked_indices = noisy_batch == mask_id + + prompt_masks = prompt_masks.to(torch.int64) + answer_lengths = torch.sum((1 - prompt_masks), dim=-1, keepdim=True) + answer_lengths = answer_lengths.repeat(1, noisy_batch.shape[1]) + + return noisy_batch, labels_mmu, p_mask, answer_lengths + + @torch.no_grad() + def prepare_inputs_and_labels_for_t2s( + input_ids_t2s, prompt_masks, labels_t2s, eps=1e-3 + ): + b, l = input_ids_t2s.shape + t = torch.rand(b, device=input_ids_t2s.device) + p_mask = (1 - eps) * t + eps + p_mask = p_mask[:, None].repeat(1, l) + + masked_indices = torch.rand((b, l), device=input_ids_t2s.device) < p_mask + noisy_batch = torch.where(masked_indices, mask_id, input_ids_t2s) + masked_indices = noisy_batch == mask_id + + noisy_batch[prompt_masks.bool()] = input_ids_t2s[prompt_masks.bool()] + masked_indices = noisy_batch == mask_id + + prompt_masks = prompt_masks.to(torch.int64) + answer_lengths = torch.sum((1 - prompt_masks), dim=-1, keepdim=True) + answer_lengths = answer_lengths.repeat(1, noisy_batch.shape[1]) + + return noisy_batch, labels_t2s, p_mask, answer_lengths + + @torch.no_grad() + def prepare_inputs_and_labels_for_v2s( + input_ids_v2s: torch.Tensor, + prompt_masks: torch.Tensor, + labels_v2s: torch.Tensor, + eps: float = 1e-3, + ): + b, l = input_ids_v2s.shape + device_local = input_ids_v2s.device + + p_mask = eps + (1.0 - eps) * torch.rand(b, device=device_local) + p_mask = p_mask.unsqueeze(1).expand(b, l) + + rand_mat = torch.rand((b, l), device=device_local) + answer_region = (prompt_masks == 0) + masked_indices = (rand_mat < p_mask) & answer_region + + noisy_batch = input_ids_v2s.clone() + noisy_batch[masked_indices] = mask_id + + answer_lengths = answer_region.sum(dim=-1, keepdim=True) + expanded_lengths = answer_lengths.expand(b, l) + + return noisy_batch.long(), labels_v2s.long(), p_mask, expanded_lengths.long() + + + @torch.no_grad() + def prepare_inputs_and_labels_for_s2s( + input_ids_s2s, prompt_masks, labels_s2s, eps=1e-3 + ): + b, l = input_ids_s2s.shape + t = torch.rand(b, device=input_ids_s2s.device) + p_mask = (1 - eps) * t + eps + p_mask = p_mask[:, None].repeat(1, l) + + masked_indices = torch.rand((b, l), device=input_ids_s2s.device) < p_mask + noisy_batch = torch.where(masked_indices, mask_id, input_ids_s2s) + masked_indices = noisy_batch == mask_id + + noisy_batch[prompt_masks.bool()] = input_ids_s2s[prompt_masks.bool()] + masked_indices = noisy_batch == mask_id + + prompt_masks = prompt_masks.to(torch.int64) + answer_lengths = torch.sum((1 - prompt_masks), dim=-1, keepdim=True) + answer_lengths = answer_lengths.repeat(1, noisy_batch.shape[1]) + + return noisy_batch, labels_s2s, p_mask, answer_lengths + + + @torch.no_grad() + def prepare_inputs_and_labels_for_s2t( + input_ids_mmu, prompt_masks, labels_mmu, eps=1e-3 + ): + b, l = input_ids_mmu.shape + t = torch.rand(b, device=input_ids_mmu.device) + p_mask = (1 - eps) * t + eps + p_mask = p_mask[:, None].repeat(1, l) + + masked_indices = torch.rand((b, l), device=input_ids_mmu.device) < p_mask + # 126336 is used for [MASK] token + noisy_batch = torch.where(masked_indices, mask_id, input_ids_mmu) + masked_indices = noisy_batch == mask_id + noisy_batch[prompt_masks.bool()] = input_ids_mmu[prompt_masks.bool()] + masked_indices = noisy_batch == mask_id + + prompt_masks = prompt_masks.to(torch.int64) + answer_lengths = torch.sum((1 - prompt_masks), dim=-1, keepdim=True) + answer_lengths = answer_lengths.repeat(1, noisy_batch.shape[1]) + + return noisy_batch, labels_mmu, p_mask, answer_lengths + + batch_time_m = AverageMeter() + data_time_m = AverageMeter() + end = time.time() + + v2t_iterator: Optional[Iterator] = None + v2s_iterator: Optional[Iterator] = None + t2i_iterator: Optional[Iterator] = None + i2i_iterator: Optional[Iterator] = None + + def _next_from_v2t(): + nonlocal v2t_iterator + if train_dataloader_v2t is None: + return None + try: + return next(v2t_iterator) + except StopIteration: + v2t_iterator = iter(train_dataloader_v2t) + return next(v2t_iterator) + + def _next_from_v2s(): + nonlocal v2s_iterator + if train_dataloader_v2s is None: + return None + try: + return next(v2s_iterator) + except StopIteration: + v2s_iterator = iter(train_dataloader_v2s) + return next(v2s_iterator) + + def _next_from_t2i(): + nonlocal t2i_iterator + if train_dataloader_t2i is None: + return None + try: + return next(t2i_iterator) + except StopIteration: + t2i_iterator = iter(train_dataloader_t2i) + return next(t2i_iterator) + + def _next_from_i2i(): + nonlocal i2i_iterator + if train_dataloader_i2i is None: + return None + try: + return next(i2i_iterator) + except StopIteration: + i2i_iterator = iter(train_dataloader_i2i) + return next(i2i_iterator) + + def _next_from_s2t(): + nonlocal s2t_iterator + if train_dataloader_s2t is None: + return None + try: + return next(s2t_iterator) + except StopIteration: + s2t_iterator = iter(train_dataloader_s2t) + return next(s2t_iterator) + + def _next_from_t2s(): + nonlocal t2s_iterator + if train_dataloader_t2s is None: + return None + try: + return next(t2s_iterator) + except StopIteration: + t2s_iterator = iter(train_dataloader_t2s) + return next(t2s_iterator) + + v2t_iterator = iter(train_dataloader_v2t) if train_dataloader_v2t is not None else None + v2s_iterator = iter(train_dataloader_v2s) if train_dataloader_v2s is not None else None + t2i_iterator = iter(train_dataloader_t2i) if train_dataloader_t2i is not None else None + i2i_iterator = iter(train_dataloader_i2i) if train_dataloader_i2i is not None else None + s2t_iterator = iter(train_dataloader_s2t) if train_dataloader_s2t is not None else None + t2s_iterator = iter(train_dataloader_t2s) if train_dataloader_t2s is not None else None + + for epoch in tqdm(range(first_epoch, num_train_epochs), desc="Epochs", disable=not accelerator.is_main_process, position=0): + # Ensure all samplers reshuffle in a rank-consistent way each epoch + try: + if isinstance(sampler_t2i, DistributedSampler): + sampler_t2i.set_epoch(epoch) + if isinstance(sampler_i2i, DistributedSampler): + sampler_i2i.set_epoch(epoch) + if isinstance(sampler_v2t, DistributedSampler): + sampler_v2t.set_epoch(epoch) + if isinstance(sampler_v2s, DistributedSampler): + sampler_v2s.set_epoch(epoch) + if accelerator.num_processes > 1: + if sampler_s2t is not None: + sampler_s2t.set_epoch(epoch) + if sampler_t2s is not None: + sampler_t2s.set_epoch(epoch) + if sampler_s2s is not None: + sampler_s2s.set_epoch(epoch) + except Exception: + pass + model.train() + combined_iterator = iter(combined_dataloader) + while True: + skip_local = 0 + timeout_encountered = False + timeout_message: Optional[str] = None + try: + batch, batch_idx, dataloader_idx = next(combined_iterator) + except StopIteration: + break + except RuntimeError as exc: + if "DataLoader timed out" in str(exc): + skip_local = 1 + timeout_encountered = True + timeout_message = str(exc) + batch = None + batch_idx = None + dataloader_idx = None + else: + raise + + if batch is None: + skip_local = 1 + + skip_tensor = torch.tensor(skip_local, device=accelerator.device, dtype=torch.int32) + skip_sum = accelerator.reduce(skip_tensor, reduction="sum") + if skip_sum.item() > 0: + timeout_tensor = torch.tensor(1 if timeout_encountered else 0, device=accelerator.device, dtype=torch.int32) + timeout_sum = accelerator.reduce(timeout_tensor, reduction="sum") + if accelerator.is_main_process: + if timeout_sum.item() > 0: + logger.warning( + "Skipping global step %s due to DataLoader timeout: %s", + global_step, + timeout_message or "timeout on non-main rank", + ) + else: + logger.warning( + "Skipping global step %s due to empty batch from CombinedLoader.", + global_step, + ) + batch_time_m.reset() + data_time_m.reset() + end = time.time() + continue + + v2t_batch = None + v2s_batch = None + selected_v2_branch: Optional[str] = None + t2i_batch = None + i2i_batch = None + selected_x2i_branch: Optional[str] = None + + v2_choices: list[str] = [] + if train_dataloader_v2t is not None: + v2_choices.append("v2t") + if train_dataloader_v2s is not None: + v2_choices.append("v2s") + + if v2_choices: + if accelerator.num_processes > 1: + local_choice = random.randrange(len(v2_choices)) if accelerator.is_main_process else 0 + choice_idx = _broadcast_choice(local_choice, accelerator) + else: + choice_idx = random.randrange(len(v2_choices)) + + if choice_idx < 0 or choice_idx >= len(v2_choices): + if accelerator.is_main_process: + logger.warning( + "Received out-of-range v2 branch index %s for %s choices; clamping.", + choice_idx, + len(v2_choices), + ) + choice_idx = choice_idx % len(v2_choices) + + selected_v2_branch = v2_choices[choice_idx] + if selected_v2_branch == "v2t": + v2t_batch = _next_from_v2t() + else: + v2s_batch = _next_from_v2s() + + batch["v2t_flow"] = v2t_batch + batch["v2s_flow"] = v2s_batch + + # Initialize speech flows with empty placeholders; they will be populated if selected. + batch["s2t_flow"] = _empty_audio_batch() + batch["t2s_flow"] = _empty_audio_batch() + + speech_choices: list[str] = [] + if train_dataloader_s2t is not None: + speech_choices.append("s2t") + if train_dataloader_t2s is not None: + speech_choices.append("t2s") + + selected_speech_branch: Optional[str] = None + if speech_choices: + if accelerator.num_processes > 1: + local_choice = random.randrange(len(speech_choices)) if accelerator.is_main_process else 0 + choice_idx = _broadcast_choice(local_choice, accelerator) + else: + choice_idx = random.randrange(len(speech_choices)) + + if choice_idx < 0 or choice_idx >= len(speech_choices): + if accelerator.is_main_process: + logger.warning( + "Received out-of-range speech branch index %s for %s choices; clamping.", + choice_idx, + len(speech_choices), + ) + choice_idx = choice_idx % len(speech_choices) + + selected_speech_branch = speech_choices[choice_idx] + if selected_speech_branch == "s2t": + speech_batch = _next_from_s2t() + if speech_batch is None: + skip_local = 1 + else: + batch["s2t_flow"] = speech_batch + else: + speech_batch = _next_from_t2s() + if speech_batch is None: + skip_local = 1 + else: + batch["t2s_flow"] = speech_batch + + x2i_choices: list[str] = [] + if train_dataloader_t2i is not None: + x2i_choices.append("t2i") + if train_dataloader_i2i is not None: + x2i_choices.append("i2i") + + if x2i_choices: + if accelerator.num_processes > 1: + local_choice = random.randrange(len(x2i_choices)) if accelerator.is_main_process else 0 + choice_idx = _broadcast_choice(local_choice, accelerator) + else: + choice_idx = random.randrange(len(x2i_choices)) + + if choice_idx < 0 or choice_idx >= len(x2i_choices): + if accelerator.is_main_process: + logger.warning( + "Received out-of-range x2i branch index %s for %s choices; clamping.", + choice_idx, + len(x2i_choices), + ) + choice_idx = choice_idx % len(x2i_choices) + + selected_x2i_branch = x2i_choices[choice_idx] + if selected_x2i_branch == "t2i": + t2i_batch = _next_from_t2i() + else: + i2i_batch = _next_from_i2i() + + # Synchronize skip decision across all ranks to avoid collective mismatches + required_flows = ["t2s_flow", "s2t_flow"] + if train_dataloader_lm is not None: + required_flows.append("lm_flow") + if train_dataloader_mmu is not None: + required_flows.append("mmu_flow") + if train_dataloader_s2s is not None: + required_flows.append("s2s_flow") + + local_skip = 0 + if selected_v2_branch == "v2t" and v2t_batch is None: + local_skip = 1 + elif selected_v2_branch == "v2s" and v2s_batch is None: + local_skip = 1 + else: + for key in required_flows: + if batch.get(key) is None: + local_skip = 1 + break + if selected_x2i_branch == "t2i": + if t2i_batch is None: + local_skip = 1 + else: + t2i_images = t2i_batch["t2i"].get("images") + if not isinstance(t2i_images, torch.Tensor) or t2i_images.shape[0] == 0: + local_skip = 1 + elif selected_x2i_branch == "i2i": + if i2i_batch is None: + local_skip = 1 + else: + i2i_sources = i2i_batch["i2i"].get("source_images") + i2i_targets = i2i_batch["i2i"].get("target_images") + if ( + not isinstance(i2i_sources, torch.Tensor) + or not isinstance(i2i_targets, torch.Tensor) + or i2i_sources.shape[0] == 0 + or i2i_targets.shape[0] == 0 + ): + local_skip = 1 + try: + skip_tensor = torch.tensor(local_skip, device=accelerator.device, dtype=torch.int32) + skip_sum = accelerator.reduce(skip_tensor, reduction='sum') + should_skip = skip_sum.item() > 0 + except Exception: + # Fallback if reduce isn't available for any reason + should_skip = local_skip == 1 + + if should_skip: + if accelerator.is_main_process and local_skip: + logger.warning(f"Skipping step {global_step} (required multimodal batch missing) [synced]") + continue + + device = accelerator.device + batch_size_v2s = 0 + input_ids_v2s = torch.empty((0, 1), dtype=torch.long, device=device) + labels_v2s = torch.empty((0, 1), dtype=torch.long, device=device) + p_mask_v2s = torch.empty((0, 1), dtype=torch.float32, device=device) + answer_lengths_v2s = torch.empty((0, 1), dtype=torch.long, device=device) + # Text-to-image samples + batch_size_t2i = 0 + mask_prob = torch.tensor(0.0, device=device) + t2i_masks = torch.empty((0, 1), dtype=torch.long, device=device) + input_ids_t2i = torch.empty((0, 1), dtype=torch.long, device=device) + labels_t2i = torch.empty((0, 1), dtype=torch.long, device=device) + batch_size_i2i = 0 + mask_prob_i2i = torch.tensor(0.0, device=device) + input_ids_i2i = torch.empty((0, 1), dtype=torch.long, device=device) + labels_i2i = torch.empty((0, 1), dtype=torch.long, device=device) + attention_masks_i2i = torch.empty((0, 1), dtype=torch.long, device=device) + + if selected_x2i_branch == "t2i" and t2i_batch is not None: + t2i_texts = t2i_batch["t2i"].get("texts", []) + t2i_images_tensor = t2i_batch["t2i"].get("images") + if isinstance(t2i_images_tensor, torch.Tensor) and t2i_images_tensor.shape[0] > 0: + t2i_images_tensor = t2i_images_tensor.to(device, non_blocking=True) + batch_size_t2i = t2i_images_tensor.shape[0] + ( + input_ids_t2i, + labels_t2i, + mask_prob, + _, + t2i_masks, + ) = prepare_inputs_and_labels(t2i_images_tensor, t2i_texts, config.training.min_masking_rate) + input_ids_t2i = input_ids_t2i.to(device, non_blocking=True) + labels_t2i = labels_t2i.to(device, non_blocking=True) + t2i_masks = t2i_masks.to(device, non_blocking=True) + if mask_prob.device != device: + mask_prob = mask_prob.to(device) + + if selected_x2i_branch == "i2i" and i2i_batch is not None: + i2i_prompts = i2i_batch["i2i"].get("prompts", []) + i2i_source_tensor = i2i_batch["i2i"].get("source_images") + i2i_target_tensor = i2i_batch["i2i"].get("target_images") + if ( + isinstance(i2i_source_tensor, torch.Tensor) + and isinstance(i2i_target_tensor, torch.Tensor) + and i2i_source_tensor.shape[0] > 0 + and i2i_target_tensor.shape[0] > 0 + ): + i2i_source_tensor = i2i_source_tensor.to(device, non_blocking=True) + i2i_target_tensor = i2i_target_tensor.to(device, non_blocking=True) + batch_size_i2i = i2i_source_tensor.shape[0] + ( + input_ids_i2i, + labels_i2i, + mask_prob_i2i, + attention_masks_i2i, + ) = prepare_inputs_and_labels_for_i2i( + i2i_source_tensor, + i2i_target_tensor, + i2i_prompts, + is_train=True, + ) + input_ids_i2i = input_ids_i2i.to(device, non_blocking=True) + labels_i2i = labels_i2i.to(device, non_blocking=True) + attention_masks_i2i = attention_masks_i2i.to(device, non_blocking=True) + if mask_prob_i2i.device != device: + mask_prob_i2i = mask_prob_i2i.to(device) + + # Language modeling samples + batch_size_lm = 0 + input_ids_lm = torch.empty((0, 1), dtype=torch.long, device=device) + labels_lm = torch.empty((0, 1), dtype=torch.long, device=device) + p_mask_lm = torch.empty((0, 1), dtype=torch.float32, device=device) + if train_dataloader_lm is not None: + lm_batch = batch.get("lm_flow") + if lm_batch is not None: + texts_lm = lm_batch["input_ids"] + batch_size_lm = len(texts_lm) + max_seq_for_lm = input_ids_t2i.shape[1] if batch_size_t2i > 0 else preproc_config.max_seq_length + input_ids_lm, labels_lm, p_mask_lm = prepare_inputs_and_labels_for_text(texts_lm, max_seq_for_lm) + input_ids_lm = input_ids_lm.to(device, non_blocking=True) + labels_lm = labels_lm.to(device, non_blocking=True) + p_mask_lm = p_mask_lm.to(device, non_blocking=True) + + if isinstance(v2t_batch, dict): + video_tensor_text_raw = v2t_batch.get("video") + texts_vid = v2t_batch.get("captions", []) + else: + video_tensor_text_raw = None + texts_vid = [] + + if isinstance(v2s_batch, dict): + video_tensor_speech_raw = v2s_batch.get("video") + speech_items = v2s_batch.get("speech", []) + else: + video_tensor_speech_raw = None + speech_items = [] + + batch_size_v2t = video_tensor_text_raw.shape[0] if isinstance(video_tensor_text_raw, torch.Tensor) else 0 + batch_size_v2s = len(speech_items) + + video_tensor_text = ( + video_tensor_text_raw.to(device, non_blocking=True) + if isinstance(video_tensor_text_raw, torch.Tensor) + else torch.empty((0, 1, 1, 1, 1), device=device) + ) + video_tensor_speech = ( + video_tensor_speech_raw.to(device, non_blocking=True) + if isinstance(video_tensor_speech_raw, torch.Tensor) + else torch.empty((0, 1, 1, 1, 1), device=device) + ) + + batch_size_t2s_text = len(batch["t2s_flow"]["audio_path"]) + batch_size_s2t = len(batch["s2t_flow"]["audio_path"]) + + s2s_batch = batch.get("s2s_flow") + batch_size_s2s = 0 + if s2s_batch is not None: + batch_size_s2s = len(s2s_batch.get("emova_sft", [])) + len(s2s_batch.get("instructs2s", [])) + + mmu_batch = batch.get("mmu_flow") + batch_size_mmu = 0 + image_tensor_list = [] + texts_image = [] + if mmu_batch is not None: + image_tensor_list = mmu_batch.get("images", []) + texts_image = mmu_batch.get("text", []) + batch_size_mmu = len(image_tensor_list) + + s2t_flow = batch.get("s2t_flow", {}) + t2s_flow = batch.get("t2s_flow", {}) + audio_paths_s2t, texts_s2t = s2t_flow.get("audio_path", []), s2t_flow.get("text", []) + audio_paths_t2s, texts_t2s = t2s_flow.get("audio_path", []), t2s_flow.get("text", []) + audio_tokens_s2t = s2t_flow.get("audio_tokens", []) + audio_tokens_t2s = t2s_flow.get("audio_tokens", []) + + if batch_size_s2t > 0 and batch_size_t2s_text > 0: + if accelerator.num_processes > 1: + local_choice = 0 if accelerator.is_main_process and random.random() < 0.5 else 1 + drop_t2s = _broadcast_choice(local_choice, accelerator) == 0 + else: + drop_t2s = random.random() < 0.5 + + if drop_t2s: + audio_paths_t2s = [] + texts_t2s = [] + batch_size_t2s_text = 0 + else: + audio_paths_s2t = [] + texts_s2t = [] + batch_size_s2t = 0 + else: + batch_size_s2t = len(audio_paths_s2t) + batch_size_t2s_text = len(audio_paths_t2s) + + active_x2i_branch = selected_x2i_branch or "none" + logger.info( + f"x2i_branch: {active_x2i_branch}, batch_size_t2i: {batch_size_t2i}, batch_size_i2i: {batch_size_i2i}, batch_size_lm: {batch_size_lm}, " + f"batch_size_v2t: {batch_size_v2t}, batch_size_v2s: {batch_size_v2s}, batch_size_t2s: {batch_size_t2s_text}, " + f"batch_size_s2t: {batch_size_s2t}, batch_size_s2s: {batch_size_s2s}, batch_size_mmu: {batch_size_mmu}" + ) + offset = speech_vocab_start + + data_time_m.update(time.time() - end) + + failure_messages = [] + step_failed = False + + input_ids_vid = torch.empty((0, 1), dtype=torch.long, device=device) + labels_vid = torch.empty((0, 1), dtype=torch.long, device=device) + p_mask_vid = torch.empty((0, 1), dtype=torch.float32, device=device) + answer_lengths_vid = torch.empty((0, 1), dtype=torch.long, device=device) + + input_ids_v2s = torch.empty((0, 1), dtype=torch.long, device=device) + labels_v2s = torch.empty((0, 1), dtype=torch.long, device=device) + p_mask_v2s = torch.empty((0, 1), dtype=torch.float32, device=device) + answer_lengths_v2s = torch.empty((0, 1), dtype=torch.long, device=device) + + if batch_size_v2t > 0: + video_token_list = [] + for vid_idx, video in enumerate(video_tensor_text): + tokens, err = safe_video_get_code(video, vid_idx) + if err is not None: + failure_messages.append(err) + step_failed = True + break + tokens = tokens + len(uni_prompting.text_tokenizer) + video_token_list.append(tokens.view(-1)) + + if not step_failed and video_token_list: + video_tokens_text = torch.stack(video_token_list, dim=0) + + texts_with_prompt: List[str] + is_vid_inst = False + if texts_vid and isinstance(texts_vid[0], (list, tuple)) and isinstance(texts_vid[0][0], dict): + is_vid_inst = True + vid_inst_prompt: List[str] = [] + vid_inst_answer: List[str] = [] + for conv in texts_vid: + human_msg = "" + assistant_msg = "" + for turn in conv: + role = turn.get("from") + value = turn.get("value", "") + if role == "human": + human_msg = value.replace("\n", "") + elif role == "gpt": + assistant_msg = value + vid_inst_prompt.append(human_msg) + vid_inst_answer.append(assistant_msg) + texts_with_prompt = [ + "<|start_header_id|>user<|end_header_id|>\n" + f"{vid_inst_prompt[i]}<|eot_id|>" + "<|start_header_id|>assistant<|end_header_id|>\n" + f"{vid_inst_answer[i]}" + for i in range(len(vid_inst_answer)) + ] + else: + prompt_v2t_selected = random.choice(V2T_INSTRUCTION) + texts_with_prompt = [ + "<|start_header_id|>user<|end_header_id|>\n" + f"{prompt_v2t_selected}<|eot_id|>" + "<|start_header_id|>assistant<|end_header_id|>\n" + f"{text if isinstance(text, str) else str(text)}" + for text in texts_vid + ] + + input_ids_vid_tmp, prompt_masks_vid, labels_vid_tmp = uni_prompting((video_tokens_text, texts_with_prompt), 'v2t') + input_ids_vid_tmp, labels_vid_tmp, p_mask_vid, answer_lengths_vid = prepare_inputs_and_labels_for_mmu( + input_ids_vid_tmp, prompt_masks_vid, labels_vid_tmp + ) + input_ids_vid = input_ids_vid_tmp.to(device, non_blocking=True) + labels_vid = labels_vid_tmp.to(device, non_blocking=True) + p_mask_vid = p_mask_vid.to(device, non_blocking=True) + answer_lengths_vid = answer_lengths_vid.to(device, non_blocking=True) + else: + batch_size_v2t = 0 + + if batch_size_v2s > 0 and not step_failed: + all_audio_tokens: list[torch.Tensor] = [] + for speech_entry in speech_items: + if isinstance(speech_entry, torch.Tensor): + tokens = speech_entry.to(device, non_blocking=True) + else: + tokens, err = safe_audio_encode(speech_entry, "v2s") + if err is not None: + failure_messages.append(err) + step_failed = True + break + tokens = tokens.to(device, non_blocking=True) + tokens_with_offset = tokens + offset + all_audio_tokens.append(tokens_with_offset) + + video_token_list_v2s: list[torch.Tensor] = [] + if not step_failed: + for vid_idx, video in enumerate(video_tensor_speech): + tokens, err = safe_video_get_code(video, vid_idx) + if err is not None: + failure_messages.append(err) + step_failed = True + break + tokens = tokens + len(uni_prompting.text_tokenizer) + video_token_list_v2s.append(tokens.view(-1)) + + if not step_failed and all_audio_tokens and video_token_list_v2s: + video_tokens_v2s = torch.stack(video_token_list_v2s, dim=0) + prompts_v2s = [ + f"<|start_header_id|>user<|end_header_id|>\n{random.choice(V2S_INSTRUCTION)}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n" + for _ in range(batch_size_v2s) + ] + input_ids_v2s_tmp, prompt_masks_v2s, labels_v2s_tmp = uni_prompting( + (video_tokens_v2s, prompts_v2s, all_audio_tokens), 'v2s_ip' + ) + input_ids_v2s_tmp, labels_v2s_tmp, p_mask_v2s_tmp, answer_lengths_v2s_tmp = prepare_inputs_and_labels_for_v2s( + input_ids_v2s_tmp, prompt_masks_v2s, labels_v2s_tmp + ) + input_ids_v2s = input_ids_v2s_tmp.to(device, non_blocking=True) + labels_v2s = labels_v2s_tmp.to(device, non_blocking=True) + p_mask_v2s = p_mask_v2s_tmp.to(device, non_blocking=True) + answer_lengths_v2s = answer_lengths_v2s_tmp.to(device, non_blocking=True) + else: + batch_size_v2s = 0 + + # *-------*-------*-------*-------*-------*-------*-------*-------*-------*-------*-------* + # Build formatted sequences for speech understanding + # *-------*-------*-------*-------*-------*-------*-------*-------*-------*-------*-------* + if not step_failed and batch_size_s2t > 0: + prompt_s2t = ['<|start_header_id|>user<|end_header_id|>\n' + prompt + '<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n' for prompt in S2T_INSTRUCTION] + + all_audio_tokens = [] + if not audio_tokens_s2t: + audio_tokens_s2t = [None] * len(audio_paths_s2t) + elif len(audio_tokens_s2t) < len(audio_paths_s2t): + audio_tokens_s2t = list(audio_tokens_s2t) + [None] * (len(audio_paths_s2t) - len(audio_tokens_s2t)) + + for path, cached_tokens in zip(audio_paths_s2t, audio_tokens_s2t): + source = cached_tokens if cached_tokens is not None else path + tokens, err = safe_audio_encode(source, "s2t") + if err is not None: + failure_messages.append(err) + step_failed = True + break + tokens = tokens.to(accelerator.device, non_blocking=True) + tokens_with_offset = tokens + offset + all_audio_tokens.append(tokens_with_offset) + + if not step_failed: + prompt = random.choice(prompt_s2t) + texts_with_prompt = [f"{prompt}{text}" for text in texts_s2t] + + input_ids_s2t, prompt_masks_s2t, labels_s2t = uni_prompting((all_audio_tokens, texts_with_prompt), 's2t') + # Preserve trailing EOS tokens in s2t targets for explicit prediction. + input_ids_s2t, labels_s2t, p_mask_s2t, answer_lengths_s2t = prepare_inputs_and_labels_for_s2t(input_ids_s2t, prompt_masks_s2t, labels_s2t) + else: + input_ids_s2t = torch.empty((0, 1), dtype=torch.long, device=accelerator.device) + labels_s2t = torch.empty((0, 1), dtype=torch.long, device=accelerator.device) + p_mask_s2t = torch.empty((0, 1), dtype=torch.float32, device=accelerator.device) + answer_lengths_s2t = torch.empty((0, 1), dtype=torch.long, device=accelerator.device) + + # *-------*-------*-------*-------*-------*-------*-------*-------*-------*-------*-------* + # Build formatted sequences for speech generation + # *-------*-------*-------*-------*-------*-------*-------*-------*-------*-------*-------* + if not step_failed and batch_size_t2s_text > 0: + prompt_t2s = [prompt for prompt in T2S_INSTRUCTION] + + all_audio_tokens = [] + if not audio_tokens_t2s: + audio_tokens_t2s = [None] * len(audio_paths_t2s) + elif len(audio_tokens_t2s) < len(audio_paths_t2s): + audio_tokens_t2s = list(audio_tokens_t2s) + [None] * (len(audio_paths_t2s) - len(audio_tokens_t2s)) + + for path, cached_tokens in zip(audio_paths_t2s, audio_tokens_t2s): + source = cached_tokens if cached_tokens is not None else path + tokens, err = safe_audio_encode(source, "t2s") + if err is not None: + failure_messages.append(err) + step_failed = True + break + tokens = tokens.to(accelerator.device, non_blocking=True) + tokens_with_offset = tokens + offset + all_audio_tokens.append(tokens_with_offset) + + if not step_failed: + # Chat-style instruction formatting for T2S training + prompt = random.choice(prompt_t2s) + texts_with_prompt = [ + f"<|start_header_id|>user<|end_header_id|>\n{prompt}\n{text}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n" + for text in texts_t2s + ] + + # input_ids_t2s, prompt_masks_t2s, labels_t2s = uni_prompting((texts_with_prompt, all_audio_tokens), 't2s_ip') + input_ids_t2s, prompt_masks_t2s, labels_t2s = uni_prompting((texts_with_prompt, all_audio_tokens), 't2s_ip') + input_ids_t2s, labels_t2s, p_mask_t2s, answer_lengths_t2s = prepare_inputs_and_labels_for_t2s(input_ids_t2s, prompt_masks_t2s, labels_t2s) + else: + input_ids_t2s = torch.empty((0, 1), dtype=torch.long, device=accelerator.device) + labels_t2s = torch.empty((0, 1), dtype=torch.long, device=accelerator.device) + p_mask_t2s = torch.empty((0, 1), dtype=torch.float32, device=accelerator.device) + answer_lengths_t2s = torch.empty((0, 1), dtype=torch.long, device=accelerator.device) + + audio_user_ids_s2s: list[torch.Tensor] = [] + audio_asst_ids_s2s: list[torch.Tensor] = [] + image_token_blocks_s2s: list[Optional[torch.Tensor]] = [] + input_ids_s2s = None + labels_s2s = None + p_mask_s2s = None + answer_lengths_s2s = None + + if not step_failed and batch_size_s2s > 0 and s2s_batch is not None: + s2s_sample_counter = 0 + + emova_samples = s2s_batch.get("emova_sft", []) + for sample_idx, (usr_ids, asst_ids, image_like) in enumerate(emova_samples): + usr_tensor = torch.tensor(usr_ids, dtype=torch.long, device=accelerator.device).unsqueeze(0) + asst_tensor = torch.tensor(asst_ids, dtype=torch.long, device=accelerator.device).unsqueeze(0) + + audio_user_ids_s2s.append(usr_tensor + offset) + audio_asst_ids_s2s.append(asst_tensor + offset) + + decoded_payload = maybe_decode_image(image_like) + if isinstance(decoded_payload, (list, tuple)): + token_payloads = [] + for local_img in decoded_payload: + if local_img is None: + continue + pixel_values = image_transform(local_img, resolution=preproc_config.resolution).to(accelerator.device) + if pixel_values.dim() == 3: + pixel_values = pixel_values.unsqueeze(0) + image_tokens_raw, err = safe_image_get_code(pixel_values, s2s_sample_counter) + if err is not None: + failure_messages.append(err) + step_failed = True + break + image_tokens = image_tokens_raw + len(uni_prompting.text_tokenizer) + token_payloads.append(image_tokens.view(1, -1).to(accelerator.device, non_blocking=True)) + s2s_sample_counter += 1 + if step_failed: + break + image_token_blocks_s2s.append(token_payloads if token_payloads else None) + else: + if decoded_payload is not None: + pixel_values = image_transform(decoded_payload, resolution=preproc_config.resolution).to(accelerator.device) + if pixel_values.dim() == 3: + pixel_values = pixel_values.unsqueeze(0) + image_tokens_raw, err = safe_image_get_code(pixel_values, s2s_sample_counter) + if err is not None: + failure_messages.append(err) + step_failed = True + break + image_tokens = image_tokens_raw + len(uni_prompting.text_tokenizer) + image_token_blocks_s2s.append(image_tokens.view(1, -1).to(accelerator.device, non_blocking=True)) + s2s_sample_counter += 1 + else: + image_token_blocks_s2s.append(None) + + instruct_samples = [] if step_failed else s2s_batch.get("instructs2s", []) + if not step_failed: + for sample_idx, (usr_wav, asst_wav, _) in enumerate(instruct_samples): + user_tokens, err_usr = safe_audio_encode(usr_wav, "s2s-user") + if err_usr is not None: + failure_messages.append(err_usr) + step_failed = True + break + asst_tokens, err_asst = safe_audio_encode(asst_wav, "s2s-assistant") + if err_asst is not None: + failure_messages.append(err_asst) + step_failed = True + break + + user_tokens = user_tokens.to(accelerator.device, non_blocking=True) + asst_tokens = asst_tokens.to(accelerator.device, non_blocking=True) + if user_tokens.size(-1) > config.dataset.preprocessing.max_aud_length: + duration_tokens = config.dataset.preprocessing.max_aud_length + user_tokens = user_tokens[..., :duration_tokens] + if asst_tokens.size(-1) > config.dataset.preprocessing.max_aud_length: + duration_tokens = config.dataset.preprocessing.max_aud_length + asst_tokens = asst_tokens[..., :duration_tokens] + if user_tokens.dim() == 1: + user_tokens = user_tokens.unsqueeze(0) + if asst_tokens.dim() == 1: + asst_tokens = asst_tokens.unsqueeze(0) + audio_user_ids_s2s.append(user_tokens + offset) + audio_asst_ids_s2s.append(asst_tokens + offset) + image_token_blocks_s2s.append(None) + s2s_sample_counter += 1 + + if not step_failed and audio_user_ids_s2s: + input_ids_s2s, prompt_masks_s2s, labels_s2s = uni_prompting( + (audio_user_ids_s2s, audio_asst_ids_s2s, image_token_blocks_s2s), + 's2s_ip' + ) + + input_ids_s2s, labels_s2s, p_mask_s2s, answer_lengths_s2s = prepare_inputs_and_labels_for_s2s( + input_ids_s2s, + prompt_masks_s2s, + labels_s2s, + ) + + if ( + answer_lengths_s2s is not None + and answer_lengths_s2s.numel() > 0 + and accelerator.is_main_process + ): + per_sample_lengths = answer_lengths_s2s[:, 0].detach().cpu() + lengths_list = [int(length) for length in per_sample_lengths.tolist()] + stats_msg = ( + f"min={int(per_sample_lengths.min().item())}, " + f"max={int(per_sample_lengths.max().item())}, " + f"mean={per_sample_lengths.float().mean().item():.2f}" + ) + logger.info("S2S answer lengths (no pad): %s | %s", lengths_list, stats_msg) + + if input_ids_s2s is None: + device = accelerator.device + input_ids_s2s = torch.empty((0, 1), dtype=torch.long, device=device) + labels_s2s = torch.empty((0, 1), dtype=torch.long, device=device) + p_mask_s2s = torch.empty((0, 1), dtype=torch.float32, device=device) + answer_lengths_s2s = torch.empty((0, 1), dtype=torch.long, device=device) + + input_ids_mmu = None + labels_mmu = None + p_mask_mmu = None + answer_lengths_mmu = None + + if not step_failed and batch_size_mmu > 0: + batch_image_ids_list = [] + batch_text_ids = [] + + for b_idx, image_list in enumerate(image_tensor_list): + per_img_ids = [] + for j, img in enumerate(image_list): + tok, err = safe_image_get_code( + img.to(accelerator.device, non_blocking=True), + sample_index=j + ) + if err is not None: + failure_messages.append(err) + step_failed = True + break + + tok = tok.to(accelerator.device, non_blocking=True).view(-1).long() + tok = tok + len(uni_prompting.text_tokenizer) + per_img_ids.append(tok) + + if step_failed: + break + + batch_image_ids_list.append(per_img_ids) + text_ids = uni_prompting.text_tokenizer.encode(texts_image[b_idx], add_special_tokens=False) + batch_text_ids.append(text_ids) + + if not step_failed: + input_ids_mmu, prompt_masks_mmu, labels_mmu = uni_prompting.mmu_mult_prompt( + batch_image_ids_list=batch_image_ids_list, + batch_text_ids=batch_text_ids, + ) + + ( + input_ids_mmu, + labels_mmu, + p_mask_mmu, + answer_lengths_mmu + ) = prepare_inputs_and_labels_for_mmu(input_ids_mmu, prompt_masks_mmu, labels_mmu) + + input_ids_mmu = input_ids_mmu.to(accelerator.device, non_blocking=True) + labels_mmu = labels_mmu.to(accelerator.device, non_blocking=True) + p_mask_mmu = p_mask_mmu.to(accelerator.device, non_blocking=True) + answer_lengths_mmu = answer_lengths_mmu.to(accelerator.device, non_blocking=True) + + if batch_size_mmu == 0 or input_ids_mmu is None: + input_ids_mmu = torch.empty((0, 1), dtype=torch.long, device=accelerator.device) + labels_mmu = torch.empty((0, 1), dtype=torch.long, device=accelerator.device) + p_mask_mmu = torch.empty((0, 1), dtype=torch.float32, device=accelerator.device) + answer_lengths_mmu = torch.empty((0, 1), dtype=torch.long, device=accelerator.device) + if not step_failed: + total_batch_size_t2s = batch_size_t2s_text + else: + total_batch_size_t2s = batch_size_t2s_text + + failure_tensor = torch.tensor(1 if step_failed else 0, device=accelerator.device, dtype=torch.int32) + failure_sum = accelerator.reduce(failure_tensor, reduction='sum') + if failure_sum.item() > 0: + if accelerator.is_main_process and failure_messages: + for msg in failure_messages: + logger.warning(f"Skipping global step {global_step} due to preprocessing failure: {msg}") + batch_time_m.reset() + data_time_m.reset() + end = time.time() + continue + + # -------------------------------------------------------------------------------- + # for name, tensor in [ + # ("t2i", input_ids_t2i), + # ("i2i", input_ids_i2i), + # ("lm", input_ids_lm), + # ("mmu", input_ids_mmu), + # ("vid", input_ids_vid), + # ("s2t", input_ids_s2t), + # ("s2s", input_ids_s2s), + # ("t2s", input_ids_t2s), + # ]: + # if tensor is not None: + # print(f"{name:>4}: shape={getattr(tensor, 'shape', None)}, len={len(tensor) if hasattr(tensor, '__len__') else 'N/A'}") + + # 1. Define padding values + pad_token_id = uni_prompting.text_tokenizer.eos_token_id + + # 2. Find the maximum sequence length in the current batch + seq_lengths = [] + if input_ids_t2i.shape[0] > 0: + seq_lengths.append(input_ids_t2i.shape[1]) + if input_ids_i2i.shape[0] > 0: + seq_lengths.append(input_ids_i2i.shape[1]) + if input_ids_lm.shape[0] > 0: + seq_lengths.append(input_ids_lm.shape[1]) + seq_lengths.extend([ + input_ids_vid.shape[1], + input_ids_v2s.shape[1], + input_ids_s2t.shape[1], + input_ids_t2s.shape[1], + ]) + if input_ids_s2s.shape[0] > 0: + seq_lengths.append(input_ids_s2s.shape[1]) + if input_ids_mmu.shape[0] > 0: + seq_lengths.append(input_ids_mmu.shape[1]) + max_len = max(seq_lengths) + + # 3. Pad all tensors to the max_len + input_ids_t2i = pad_tensor(input_ids_t2i, max_len, pad_token_id) + labels_t2i = pad_tensor(labels_t2i, max_len, -100) + if t2i_masks.shape[0] > 0: + t2i_masks = pad_tensor(t2i_masks.long(), max_len, 0) + else: + t2i_masks = torch.empty((0, max_len), dtype=torch.long, device=device) + + input_ids_i2i = pad_tensor(input_ids_i2i, max_len, pad_token_id) + labels_i2i = pad_tensor(labels_i2i, max_len, -100) + if attention_masks_i2i.shape[0] > 0: + attention_masks_i2i = pad_tensor(attention_masks_i2i.long(), max_len, 0) + else: + attention_masks_i2i = torch.empty((0, max_len), dtype=torch.long, device=device) + + + input_ids_lm = pad_tensor(input_ids_lm, max_len, pad_token_id) + labels_lm = pad_tensor(labels_lm, max_len, -100) + p_mask_lm = pad_tensor(p_mask_lm, max_len, 1.0) + + input_ids_vid = pad_tensor(input_ids_vid, max_len, pad_token_id) + input_ids_v2s = pad_tensor(input_ids_v2s, max_len, pad_token_id) + input_ids_s2t = pad_tensor(input_ids_s2t, max_len, pad_token_id) + input_ids_t2s = pad_tensor(input_ids_t2s, max_len, pad_token_id) + input_ids_s2s = pad_tensor(input_ids_s2s, max_len, pad_token_id) + input_ids_mmu = pad_tensor(input_ids_mmu, max_len, pad_token_id) + labels_vid = pad_tensor(labels_vid, max_len, -100) + labels_v2s = pad_tensor(labels_v2s, max_len, -100) + labels_s2t = pad_tensor(labels_s2t, max_len, -100) + labels_t2s = pad_tensor(labels_t2s, max_len, -100) + labels_s2s = pad_tensor(labels_s2s, max_len, -100) + labels_mmu = pad_tensor(labels_mmu, max_len, -100) + p_mask_vid = pad_tensor(p_mask_vid, max_len, 1.0) + p_mask_v2s = pad_tensor(p_mask_v2s, max_len, 1.0) + p_mask_s2t = pad_tensor(p_mask_s2t, max_len, 1.0) + p_mask_t2s = pad_tensor(p_mask_t2s, max_len, 1.0) + p_mask_s2s = pad_tensor(p_mask_s2s, max_len, 1.0) + p_mask_mmu = pad_tensor(p_mask_mmu, max_len, 1.0) + answer_lengths_vid = pad_answer_lengths(answer_lengths_vid, max_len) + answer_lengths_v2s = pad_answer_lengths(answer_lengths_v2s, max_len) + answer_lengths_s2t = pad_answer_lengths(answer_lengths_s2t, max_len) + answer_lengths_t2s = pad_answer_lengths(answer_lengths_t2s, max_len) + answer_lengths_s2s = pad_answer_lengths(answer_lengths_s2s, max_len) + answer_lengths_mmu = pad_answer_lengths(answer_lengths_mmu, max_len) + + input_ids = torch.cat(( + input_ids_t2i, + input_ids_i2i, + input_ids_lm, + input_ids_mmu, + input_ids_vid, + input_ids_v2s, + input_ids_s2t, + input_ids_s2s, + input_ids_t2s + ), dim=0) + labels = torch.cat(( + labels_t2i, + labels_i2i, + labels_lm, + labels_mmu, + labels_vid, + labels_v2s, + labels_s2t, + labels_s2s, + labels_t2s + ), dim=0) + + # w/o texts and images + if batch_size_lm == 0: + p_mask_lm = torch.empty((0, max_len), dtype=torch.float32, device=device) + if batch_size_t2i == 0 and t2i_masks.shape[0] == 0: + t2i_masks = torch.empty((0, max_len), dtype=torch.long, device=device) + + if global_step == 0 and epoch == 0: + logger.info("Input ids: {}".format(input_ids)) + logger.info("Labels: {}".format(labels)) + + logger.info("Input ids shape: {}".format(input_ids.shape)) + # with accelerator.accumulate(model): + logits, loss_t2i, loss_i2i, loss_lm, loss_mmu, loss_vid, loss_v2s, loss_s2t, loss_s2s, loss_t2s = accelerator.unwrap_model(model).forward_process( + # logits, loss_t2i, loss_lm, loss_mmu, loss_vid, loss_s2t, loss_t2s = model.forward_process( + input_ids=input_ids, + labels=labels, + batch_size_t2i=batch_size_t2i, + batch_size_i2i=batch_size_i2i, + batch_size_lm=batch_size_lm, + batch_size_mmu=batch_size_mmu, + batch_size_v2t=batch_size_v2t, + batch_size_v2s=batch_size_v2s, + batch_size_s2t=batch_size_s2t, + batch_size_s2s=batch_size_s2s, + batch_size_t2s=total_batch_size_t2s, + max_seq_length=config.dataset.preprocessing.max_seq_length, + attention_masks_i2i=attention_masks_i2i, + p_mask_lm=p_mask_lm, + p_mask_mmu=p_mask_mmu, + p_mask_vid=p_mask_vid, + p_mask_v2s=p_mask_v2s, + p_mask_s2t=p_mask_s2t, + p_mask_s2s=p_mask_s2s, + p_mask_t2s=p_mask_t2s, + answer_lengths_mmu=answer_lengths_mmu, + answer_lengths_vid=answer_lengths_vid, + answer_lengths_v2s=answer_lengths_v2s, + answer_lengths_s2t=answer_lengths_s2t, + answer_lengths_s2s=answer_lengths_s2s, + answer_lengths_t2s=answer_lengths_t2s, + t2i_masks=t2i_masks, + t2s_vocab_start=speech_vocab_start, + t2s_codebook_size=audio_codebook_size, + t2s_special_token_ids=t2s_special_token_ids, + text_vocab_size_override=len(uni_prompting.text_tokenizer), + ) + + if batch_size_t2i == 0: + loss_t2i = loss_t2i.new_zeros(()) + if batch_size_i2i == 0: + loss_i2i = loss_i2i.new_zeros(()) + + # Gather the losses across all processes for logging (use reduce to avoid shape mismatches) + avg_loss_t2i = accelerator.reduce(loss_t2i, reduction='mean') + avg_loss_i2i = accelerator.reduce(loss_i2i, reduction='mean') + avg_loss_lm = accelerator.reduce(loss_lm, reduction='mean') + avg_loss_mmu = accelerator.reduce(loss_mmu, reduction='mean') + avg_loss_vid = accelerator.reduce(loss_vid, reduction='mean') + avg_loss_v2s = accelerator.reduce(loss_v2s, reduction='mean') + avg_loss_s2t = accelerator.reduce(loss_s2t, reduction='mean') + avg_loss_s2s = accelerator.reduce(loss_s2s, reduction='mean') + if not torch.isfinite(loss_t2s): + if labels_t2s.numel() > 0: + speech_vocab_end = speech_vocab_start + audio_codebook_size + valid_mask = labels_t2s != -100 + if valid_mask.any(): + labels_valid = labels_t2s[valid_mask] + below_count = (labels_valid < speech_vocab_start).sum().item() + above_count = (labels_valid >= speech_vocab_end).sum().item() + labels_min = labels_valid.min().item() + labels_max = labels_valid.max().item() + else: + below_count = above_count = 0 + labels_min = labels_max = -100 + p_mask_min = p_mask_t2s.min().item() if p_mask_t2s.numel() > 0 else float("nan") + ans_len_min = ( + answer_lengths_t2s.min().item() + if answer_lengths_t2s.numel() > 0 + else float("nan") + ) + accelerator.print( + "[t2s NaN debug] " + f"rank={accelerator.process_index} step={global_step} " + f"slice=({speech_vocab_start}, {speech_vocab_end}) " + f"labels_min={labels_min} labels_max={labels_max} " + f"below_slice={below_count} above_slice={above_count} " + f"p_mask_min={p_mask_min} answer_len_min={ans_len_min}" + ) + accelerator.print( + f"[rank {accelerator.process_index}] t2s loss became NaN/Inf at global step {global_step} " + f"(local value: {loss_t2s.item()})" + ) + logger.warning( + "[rank %s] t2s loss became NaN/Inf at global step %s (local value: %s)", + accelerator.process_index, + global_step, + loss_t2s.item(), + ) + avg_loss_t2s = accelerator.reduce(loss_t2s, reduction='mean') + if not torch.isfinite(avg_loss_t2s): + accelerator.print( + f"[rank {accelerator.process_index}] reduced t2s loss NaN/Inf at global step {global_step} " + f"(value after all-reduce: {avg_loss_t2s.item()})" + ) + if accelerator.is_main_process: + logger.warning( + "Reduced t2s loss became NaN/Inf at global step %s (value after all-reduce: %s)", + global_step, + avg_loss_t2s.item(), + ) + + mmu_coeff = getattr(config.training, "mmu_coeff", 0.0) + i2i_coeff = getattr(config.training, "i2i_coeff", config.training.t2i_coeff) + s2s_coeff = getattr(config.training, "s2s_coeff", config.training.t2s_coeff) + v2s_coeff = getattr(config.training, "v2s_coeff", config.training.t2s_coeff) + loss = ( + config.training.t2i_coeff * loss_t2i + + i2i_coeff * loss_i2i + + config.training.lm_coeff * loss_lm + + mmu_coeff * loss_mmu + + config.training.v2t_coeff * loss_vid + + v2s_coeff * loss_v2s + + config.training.s2t_coeff * loss_s2t + + s2s_coeff * loss_s2s + + config.training.t2s_coeff * loss_t2s + ) + + if batch_size_t2i > 0: + local_masking_rate = mask_prob.float().mean() + else: + local_masking_rate = torch.tensor(0.0, device=accelerator.device) + avg_masking_rate = accelerator.reduce(local_masking_rate, reduction='mean') + + if batch_size_i2i > 0: + local_masking_rate_i2i = mask_prob_i2i.float().mean() + else: + local_masking_rate_i2i = torch.tensor(0.0, device=accelerator.device) + avg_masking_rate_i2i = accelerator.reduce(local_masking_rate_i2i, reduction='mean') + + if batch_size_s2s > 0 and p_mask_s2s.numel() > 0: + local_masking_rate_s2s = p_mask_s2s.float().mean() + else: + local_masking_rate_s2s = torch.tensor(0.0, device=accelerator.device) + avg_masking_rate_s2s = accelerator.reduce(local_masking_rate_s2s, reduction='mean') + + if batch_size_v2s > 0 and p_mask_v2s.numel() > 0: + local_masking_rate_v2s = p_mask_v2s.float().mean() + else: + local_masking_rate_v2s = torch.tensor(0.0, device=accelerator.device) + avg_masking_rate_v2s = accelerator.reduce(local_masking_rate_v2s, reduction='mean') + + accelerator.backward(loss) + + if config.training.max_grad_norm is not None and accelerator.sync_gradients: + accelerator.clip_grad_norm_(model.parameters(), config.training.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + + # log gradient norm before zeroing it + if ( + accelerator.sync_gradients + and (global_step + 1) % config.experiment.log_grad_norm_every == 0 + and accelerator.is_main_process + ): + log_grad_norm(model, accelerator, global_step + 1) + + optimizer.zero_grad(set_to_none=True) + + if accelerator.sync_gradients: + batch_time_m.update(time.time() - end) + end = time.time() + + # Log metrics + if (global_step + 1) % config.experiment.log_every == 0: + samples_per_second_per_gpu = ( + config.training.gradient_accumulation_steps * total_batch_size_per_gpu / batch_time_m.val + ) + logs = { + "lr": lr_scheduler.get_last_lr()[0], + "avg_masking_rate": avg_masking_rate.item(), + "avg_masking_rate_i2i": avg_masking_rate_i2i.item(), + "avg_masking_rate_v2s": avg_masking_rate_v2s.item(), + "avg_masking_rate_s2s": avg_masking_rate_s2s.item(), + "samples/sec/gpu": samples_per_second_per_gpu, + "data_time": data_time_m.val, + "batch_time": batch_time_m.val, + } + + loss_entries = [ + ("step_loss_t2i", avg_loss_t2i), + ("step_loss_i2i", avg_loss_i2i), + ("step_loss_lm", avg_loss_lm), + ("step_loss_mmu", avg_loss_mmu), + ("step_loss_vid", avg_loss_vid), + ("step_loss_v2s", avg_loss_v2s), + ("step_loss_s2t", avg_loss_s2t), + ("step_loss_s2s", avg_loss_s2s), + ("step_loss_t2s", avg_loss_t2s), + ] + + loss_log_parts = [] + for key, value in loss_entries: + loss_value = value.item() + if loss_value != 0.0: + logs[key] = loss_value + loss_log_parts.append(f"{key.replace('step_', '').capitalize()}: {loss_value:0.4f}") + + accelerator.log(logs, step=global_step + 1) + + loss_str = " ".join(loss_log_parts) + logger.info( + "Step: %d %s Data (t): %.4f, %.2f/s/gpu Batch (t): %.4f LR: %.6f" + % ( + global_step + 1, + loss_str, + data_time_m.val, + samples_per_second_per_gpu, + batch_time_m.val, + lr_scheduler.get_last_lr()[0], + ) + ) + + # resetting batch / data time meters per log window + batch_time_m.reset() + data_time_m.reset() + + # Save model checkpoint + if (global_step + 1) % config.experiment.save_every == 0: + save_checkpoint(model, config, accelerator, global_step + 1, uni_prompting) + + # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + # ++++++++++++++++++++++ RUN EVALUATION +++++++++++++++++++++++++ + # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + if global_step == 0 or (global_step + 1) % config.experiment.get("eval_every", 5000) == 0: + # if (global_step + 1) % config.experiment.get("eval_every", 5000) == 0: + run_evaluation( + model=accelerator.unwrap_model(model), + vq_model_image=vq_model_image, + vq_model_audio=vq_model_audio, + uni_prompting=uni_prompting, + config=config, + accelerator=accelerator, + global_step=global_step + 1 + ) + # Evaluation function sets model back to train mode internally + # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + + global_step += 1 + + if global_step >= config.training.max_train_steps: + break + + if global_step >= config.training.max_train_steps: + break + + accelerator.wait_for_everyone() + + save_checkpoint(model, config, accelerator, global_step, uni_prompting) + + if accelerator.is_main_process: + model = accelerator.unwrap_model(model) + model.save_pretrained(config.experiment.output_dir, safe_serialization=True) + + accelerator.end_training() + +@torch.no_grad() +def visualize_predictions(*args, **kwargs): + # This function is not called in the main loop but kept for compatibility + pass + +@torch.no_grad() +def generate_images(*args, **kwargs): + # This function is not called in the main loop but kept for compatibility + pass + +@torch.no_grad() +def understanding_images(*args, **kwargs): + # This function is not called in the main loop but kept for compatibility + pass + +def save_checkpoint(model, config, accelerator, global_step, uni_prompting): + output_dir = config.experiment.output_dir + checkpoints_total_limit = config.experiment.get("checkpoints_total_limit", None) + + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if accelerator.is_main_process and checkpoints_total_limit is not None: + checkpoints = os.listdir(output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= checkpoints_total_limit: + num_to_remove = len(checkpoints) - checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = Path(output_dir) / f"checkpoint-{global_step}" + + # retrieve the model on all processes for deepspeed stage 3 to work then save on one process (we are not using stage 3 yet) + # XXX: could also make this conditional on deepspeed + state_dict = accelerator.get_state_dict(model) + if accelerator.is_main_process: + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.save_pretrained( + save_path / "unwrapped_model", + save_function=accelerator.save, + state_dict=state_dict, + safe_serialization=True + ) + json.dump({"global_step": global_step}, (save_path / "metadata.json").open("w+")) + logger.info(f"Saved state to {save_path}") + + # save tokenizer + uni_prompting.text_tokenizer.save_pretrained(save_path/ "unwrapped_model") + + +def log_grad_norm(model, accelerator, global_step): + for name, param in model.named_parameters(): + if param.grad is not None: + grads = param.grad.detach().data + grad_norm = (grads.norm(p=2) / grads.numel()).item() + accelerator.log({"grad_norm/" + name: grad_norm}, step=global_step) + + +if __name__ == "__main__": + main() diff --git a/MMaDA/training/train_omada_inst2.py b/MMaDA/training/train_omada_inst2.py new file mode 100644 index 0000000000000000000000000000000000000000..b91840504faf40ec579241ea4e86188e7f759b48 --- /dev/null +++ b/MMaDA/training/train_omada_inst2.py @@ -0,0 +1,4282 @@ +# Copyright 2025 AIDAS Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import base64 +import binascii +import hashlib +import os +import sys +import warnings +import subprocess +import tempfile +os.environ["FFMPEG_LOG_LEVEL"] = "error" +warnings.filterwarnings("ignore") + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +os.environ["TOKENIZERS_PARALLELISM"] = "true" +import json +import logging +import math +import torch.nn.functional as F +import shutil +import time +import cv2 +import glob +import random +from itertools import zip_longest +import contextlib +from tqdm import tqdm +from pathlib import Path +from typing import Optional, Union, Dict, Any, List, Iterator +from collections.abc import Sequence +import csv +import numpy as np +from PIL import Image +from io import BytesIO +from omegaconf import OmegaConf, DictConfig +import wandb +import torch +from torch.optim import AdamW +from lightning.pytorch.utilities import CombinedLoader +import multiprocessing as py_mp +import torch.multiprocessing as mp + +try: + cv2.utils.logging.setLogLevel(cv2.utils.logging.LOG_LEVEL_ERROR) +except AttributeError: + warnings.filterwarnings("ignore", category=FutureWarning) + +from transformers import AutoTokenizer, AutoConfig +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import DistributedType, set_seed +# +++++ I2I-specific Imports +++++ +from datasets import load_dataset +from torch.utils.data import Dataset, DataLoader +from tqdm.auto import tqdm +# ++++++++++++++++++++++++++++++ + +# +++++ Omni-modal-specific Imports +++++ +from models.modeling_emova_speech_tokenizer import EMOVASpeechTokenizer +from datasets import load_dataset +from torch.utils.data import Dataset, DataLoader, DistributedSampler +from tqdm.auto import tqdm +from training.data import ( + SpeechTextDataset, + MixedSpeechTextDataset, + Speech2SpeechDataset, + TextImageInterleavedDataset, + load_video_mp4, + VideoCaptionDataset, + VideoSpeechDataset, + S2T_INSTRUCTION, + T2S_INSTRUCTION, + V2S_INSTRUCTION, + s2s_collate_fn, +) +# import librosa + +from training.data import ( + Text2ImageDataset, + HQEditX2IDataset, + CombinedX2IDataset, + HFInstructionTextDataset, + TextToImage2MDataset, + OpenImageI2IDataset, +) +from training.utils import get_config, flatten_omega_conf, image_transform +from training.imagenet_dataset import ImageNetDataset + +from models import MAGVITv2, get_mask_schedule, OMadaModelLM, OMadaConfig +from training.prompting_utils import UniversalPrompting +from models.lr_schedulers import get_scheduler +from models.logging import set_verbosity_info, set_verbosity_error + +from torch.utils.data import DataLoader, Dataset +from torch.utils.data.distributed import DistributedSampler + +# ++++++++ EVALUATION IMPORTS ++++++++ +import re +import editdistance +import soundfile as sf +from functools import partial +from transformers import pipeline +# ++++++++++++++++++++++++++++++++++++ + +SYSTEM_PROMPT_LEN = 28 + +cv2.setNumThreads(0) +torch.set_num_threads(1) +os.environ["OMP_NUM_THREADS"] = "1" +os.environ["MKL_NUM_THREADS"] = "1" +os.environ["OPENBLAS_NUM_THREADS"] = "1" +os.environ["NUMEXPR_NUM_THREADS"] = "1" +os.environ["NCCL_ASYNC_ERROR_HANDLING"]= "1" +os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL" +os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"]= "1" +os.environ["TORCH_NCCL_BLOCKING_WAIT"]= "1" + +from training.utils import get_config, flatten_omega_conf, mask_or_random_replace_tokens, AverageMeter + +try: + import apex + + is_apex_available = True +except ImportError: + is_apex_available = False + +logger = get_logger(__name__, log_level="INFO") + + +def _configure_multiprocessing() -> None: + try: + mp.set_sharing_strategy("file_descriptor") + except RuntimeError as exc: + logger.warning("Failed to set multiprocessing sharing strategy to 'file_descriptor': %s", exc) + +def pad_tensor(tensor, length, value): + pad_size = length - tensor.shape[1] + if pad_size <= 0: + return tensor + # Pad on the right side of the sequence (last dimension) + return torch.nn.functional.pad(tensor, (0, pad_size), "constant", value) + +def pad_answer_lengths(ans: torch.Tensor, length: int) -> torch.Tensor: + b, l = ans.shape + if l >= length: + return ans + pad_block = ans[:, :1].expand(b, length - l) + return torch.cat([ans, pad_block], dim=1) + +def resize_vocab(model, config): + logger.info(f"Resizing token embeddings to {config.model.omada.new_vocab_size}") + model.resize_token_embeddings(config.model.omada.new_vocab_size) + +def get_vq_model_class(model_type): + if model_type == "magvitv2": + return MAGVITv2 + elif model_type == "emova": + return EMOVASpeechTokenizer.from_pretrained( + "Emova-ollm/emova_speech_tokenizer_hf" + ) + else: + raise ValueError(f"model_type {model_type} not supported.") + +def collate_fn_audio(batch): + # In this setup, the tokenizer handles batching of audio paths + return { + 'audio_path': [item['audio_path'] for item in batch], + 'text': [item['text'] for item in batch], + 'audio_tokens': [item.get('audio_tokens') for item in batch], + } + + +def _empty_audio_batch() -> dict[str, list[Any]]: + """Utility to create an empty speech batch placeholder.""" + return { + "audio_path": [], + "text": [], + "audio_tokens": [], + } + + +def collate_fn_mmu_mult(batch): + return { + 'images': [item['images'] for item in batch], + 'text': [item['text'] for item in batch], + } + + +def collate_fn_x2i(batch): + t2i_texts: list[str] = [] + t2i_images: list[torch.Tensor] = [] + + i2i_prompts: list[str] = [] + i2i_source_images: list[torch.Tensor] = [] + i2i_target_images: list[torch.Tensor] = [] + + ref_image: Optional[torch.Tensor] = None + + has_i2i_sample = False + + for sample in batch: + input_prompt = sample.get("input_prompt") + output_prompt = sample.get("output_prompt") + edit_prompt = sample.get("edit_prompt") + inverse_prompt = sample.get("inverse_prompt") + input_image = sample.get("input_image") + output_image = sample.get("output_image") + + if isinstance(input_image, torch.Tensor) and ref_image is None: + ref_image = input_image + if isinstance(output_image, torch.Tensor) and ref_image is None: + ref_image = output_image + + has_edit_pair = ( + isinstance(input_image, torch.Tensor) + and isinstance(output_image, torch.Tensor) + and ( + (edit_prompt and edit_prompt.strip()) + or (inverse_prompt and inverse_prompt.strip()) + ) + ) + + if has_edit_pair: + has_i2i_sample = True + edit_candidates: list[tuple[str, torch.Tensor, torch.Tensor]] = [] + if edit_prompt and edit_prompt.strip(): + edit_candidates.append((edit_prompt, input_image, output_image)) + if inverse_prompt and inverse_prompt.strip(): + edit_candidates.append((inverse_prompt, output_image, input_image)) + + if edit_candidates: + chosen_prompt, chosen_src, chosen_tgt = random.choice(edit_candidates) + i2i_prompts.append(chosen_prompt) + i2i_source_images.append(chosen_src) + i2i_target_images.append(chosen_tgt) + continue + else: + if input_prompt and isinstance(input_image, torch.Tensor): + t2i_texts.append(input_prompt) + t2i_images.append(input_image) + elif output_prompt and isinstance(output_image, torch.Tensor): + t2i_texts.append(output_prompt) + t2i_images.append(output_image) + + if has_i2i_sample: + # i2iź°€ ķ•˜ė‚˜ė¼ė„ ģžˆģœ¼ė©“ ģ“ė²ˆ ė°°ģ¹˜ėŠ” i2i ģ „ģš©ģœ¼ė”œ ģ‚¬ģš©ķ•˜ź³  t2iėŠ” 비움 + t2i_texts = [] + t2i_images = [] + + def stack_images(images: list[torch.Tensor]) -> torch.Tensor: + if images: + return torch.stack(images, dim=0) + if ref_image is not None: + c, h, w = ref_image.shape[-3:] + return torch.empty((0, c, h, w), dtype=ref_image.dtype) + return torch.empty((0, 3, 0, 0), dtype=torch.float32) + + return { + "t2i": { + "texts": t2i_texts, + "images": stack_images(t2i_images), + }, + "i2i": { + "prompts": i2i_prompts, + "source_images": stack_images(i2i_source_images), + "target_images": stack_images(i2i_target_images), + }, + } + + +def collate_fn_v2t(batch: list[dict[str, Any]]) -> Optional[dict[str, Any]]: + filtered = [sample for sample in batch if sample is not None] + if not filtered: + return None + video_tensors: list[torch.Tensor] = [] + captions: list[Any] = [] + for sample in filtered: + frames = sample.get("video") + if frames is None: + continue + frame_tensor = torch.stack(frames, dim=0) + video_tensors.append(frame_tensor) + captions.append(sample.get("caption")) + if not video_tensors: + return None + return { + "video": torch.stack(video_tensors, dim=0), + "captions": captions, + } + + +def collate_fn_v2s(batch: list[dict[str, Any]]) -> Optional[dict[str, Any]]: + filtered = [sample for sample in batch if sample is not None] + if not filtered: + return None + video_tensors: list[torch.Tensor] = [] + speech_entries: list[Any] = [] + for sample in filtered: + frames = sample.get("video") + speech_value = sample.get("speech") + if frames is None or speech_value is None: + continue + frame_tensor = torch.stack(frames, dim=0) + video_tensors.append(frame_tensor) + speech_entries.append(speech_value) + if not video_tensors: + return None + return { + "video": torch.stack(video_tensors, dim=0), + "speech": speech_entries, + } + +def collate_fn_video_multimodal(batch): + text_videos: list[torch.Tensor] = [] + text_captions: list = [] + speech_videos: list[torch.Tensor] = [] + speech_entries: list[Any] = [] + + for sample in batch: + if sample is None: + continue + text_sample = sample.get("text") + if isinstance(text_sample, dict) and text_sample.get("video") is not None: + frames = text_sample["video"] + frame_tensor = torch.stack(frames, dim=0) + text_videos.append(frame_tensor) + text_captions.append(text_sample["caption"]) + + speech_sample = sample.get("speech") + if isinstance(speech_sample, dict) and speech_sample.get("video") is not None: + frames = speech_sample["video"] + frame_tensor = torch.stack(frames, dim=0) + speech_videos.append(frame_tensor) + speech_entries.append(speech_sample["speech"]) + + output: Dict[str, Any] = {} + if text_videos: + output["text"] = { + "video": torch.stack(text_videos, dim=0), + "captions": text_captions, + } + if speech_videos: + output["speech"] = { + "video": torch.stack(speech_videos, dim=0), + "speech": speech_entries, + } + if not output: + return None + return output + + +def s2t_eval_collate_fn(batch, vq_model_audio, tokenizer, uni_prompting, config): + + audio_tokens_batch = [] + offset = len(uni_prompting.text_tokenizer) + int(config.model.omada.codebook_size) + for item in batch: + audio_entry = item['audio_path'] + if isinstance(audio_entry, torch.Tensor): + tokens = audio_entry.cpu() + else: + tokens = vq_model_audio.encode(audio_entry).cpu() + tokens_with_offset = tokens + offset + audio_tokens_batch.append(tokens_with_offset) + + sptids_dict = uni_prompting.sptids_dict + device = audio_tokens_batch[0].device + batched_input_ids = [] + + for audio_tokens in audio_tokens_batch: + task_tensor = sptids_dict['<|s2t|>'].to(device).unsqueeze(0) + soa_tensor = sptids_dict['<|soa|>'].to(device).unsqueeze(0) + eoa_tensor = sptids_dict['<|eoa|>'].to(device).unsqueeze(0) + audio_block = torch.cat([task_tensor, soa_tensor, audio_tokens, eoa_tensor], dim=1) + + prompt_text = random.choice(S2T_INSTRUCTION) + full_prompt_text = f'<|start_header_id|>user<|end_header_id|>\n{prompt_text}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n' + prompt_tensor = tokenizer(full_prompt_text, return_tensors="pt").input_ids.to(device) + + final_sequence = torch.cat([audio_block, prompt_tensor], dim=1) + batched_input_ids.append(final_sequence.squeeze(0)) + + max_len = max(seq.size(0) for seq in batched_input_ids) + pad_token_id = 126093 + + final_batch_input_ids = torch.full( + (len(batched_input_ids), max_len), + pad_token_id, + dtype=torch.long, + device=device + ) + + for i, seq in enumerate(batched_input_ids): + final_batch_input_ids[i, -len(seq):] = seq + + return { + "input_ids": final_batch_input_ids, + "gt_texts": [item['gt_text'] for item in batch], + "sample_ids": [item['sample_id'] for item in batch] + } + +################################################################################################ +# +++++++++++++++++++++++++++++++++++++ EVALUATION HELPERS +++++++++++++++++++++++++++++++++++++ +################################################################################################ + +def add_gumbel_noise(logits, temperature): + ''' + The Gumbel max is a method for sampling categorical distributions. + According to arXiv:2409.02908, for MDM, low-precision Gumbel Max improves perplexity score but reduces generation quality. + Thus, we use float64. + ''' + if temperature == 0: + return logits + logits = logits.to(torch.float64) + noise = torch.rand_like(logits, dtype=torch.float64) + gumbel_noise = (- torch.log(noise)) ** temperature + return logits.exp() / gumbel_noise + + +def get_num_transfer_tokens(mask_index, steps): + ''' + In the reverse process, the interval [0, 1] is uniformly discretized into steps intervals. + Furthermore, because LLaDA employs a linear noise schedule (as defined in Eq. (8)), + the expected number of tokens transitioned at each step should be consistent. + + This function is designed to precompute the number of tokens that need to be transitioned at each step. + ''' + mask_num = mask_index.sum(dim=1, keepdim=True) + + base = mask_num // steps + remainder = mask_num % steps + + num_transfer_tokens = torch.zeros(mask_num.size(0), steps, device=mask_index.device, dtype=torch.int64) + base + + for i in range(mask_num.size(0)): + num_transfer_tokens[i, :remainder[i]] += 1 + + return num_transfer_tokens + +@ torch.no_grad() +def generate(model, prompt, steps=128, gen_length=128, block_length=128, temperature=0., + cfg_scale=0., remasking='low_confidence', mask_id=126336, attention_mask=None): + ''' + Args: + model: Mask predictor. + prompt: A tensor of shape (B, L), where B is batch size. + steps: Sampling steps, less than or equal to gen_length. + gen_length: Generated answer length. + block_length: Block length, less than or equal to gen_length. If less than gen_length, it means using semi_autoregressive remasking. + temperature: Categorical distribution sampling temperature. + cfg_scale: Unsupervised classifier-free guidance scale. + remasking: Remasking strategy. 'low_confidence' or 'random'. + mask_id: The toke id of [MASK] is 126336. + ''' + if attention_mask is not None and 0.0 in attention_mask: + attention_bias = (attention_mask[:, :, None] & attention_mask[:, None, :]).bool().unsqueeze(1) + print(f"attention_bias: {attention_bias}") + else: + attention_bias = None + batch_size = prompt.shape[0] + x = torch.full((batch_size, prompt.shape[1] + gen_length), mask_id, dtype=torch.long).to(model.device) + x[:, :prompt.shape[1]] = prompt.clone() + + prompt_index = (x != mask_id) + + assert gen_length % block_length == 0 + num_blocks = gen_length // block_length + + assert steps % num_blocks == 0 + steps = steps // num_blocks + + for num_block in range(num_blocks): + block_mask_index = (x[:, prompt.shape[1] + num_block * block_length: prompt.shape[1] + (num_block + 1) * block_length:] == mask_id) + num_transfer_tokens = get_num_transfer_tokens(block_mask_index, steps) + for i in range(steps): + mask_index = (x == mask_id) + if cfg_scale > 0.: + un_x = x.clone() + un_x[prompt_index] = mask_id + x_ = torch.cat([x, un_x], dim=0) + logits = model(x_).logits + logits, un_logits = torch.chunk(logits, 2, dim=0) + logits = un_logits + (cfg_scale + 1) * (logits - un_logits) + else: + logits = model(x, attention_bias=attention_bias).logits + + logits_with_noise = add_gumbel_noise(logits, temperature=temperature) + x0 = torch.argmax(logits_with_noise, dim=-1) # b, l + + if remasking == 'low_confidence': + p = F.softmax(logits.to(torch.float64), dim=-1) + x0_p = torch.squeeze( + torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1) # b, l + elif remasking == 'random': + x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device) + else: + raise NotImplementedError(remasking) + + x0_p[:, prompt.shape[1] + (num_block + 1) * block_length:] = -np.inf + + x0 = torch.where(mask_index, x0, x) + confidence = torch.where(mask_index, x0_p, -np.inf) + # print(confidence.shape) + transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device) + for j in range(confidence.shape[0]): + _, select_index = torch.topk(confidence[j], k=num_transfer_tokens[j, i]) + transfer_index[j, select_index] = True + x[transfer_index] = x0[transfer_index] + + return x + +def normalize_text(text): + """A simple normalizer for WER calculation.""" + text = text.lower() + text = re.sub(r"[^\w\s']", "", text) + return text + +def calculate_wer(predictions, references): + """Calculates the Word Error Rate (WER) between predicted and ground truth texts.""" + predictions = [normalize_text(p) for p in predictions] + references = [normalize_text(r) for r in references] + + total_errors = 0 + total_words = 0 + for pred, ref in zip(predictions, references): + pred_words = pred.split() + ref_words = ref.split() + total_errors += editdistance.eval(pred_words, ref_words) + total_words += len(ref_words) + + wer = total_errors / total_words if total_words > 0 else 0.0 + return wer, total_errors, total_words + +class S2TEvalDataset(Dataset): + def __init__(self, hf_dataset, root_path): + self.hf_dataset = hf_dataset + self.root_path = root_path + + def __len__(self): + return len(self.hf_dataset) + + def __getitem__(self, idx): + example = self.hf_dataset[idx] + sample_id = example['id'] + speaker_id, chapter_id, _ = sample_id.split('-') + audio_path = os.path.join(self.root_path, speaker_id, chapter_id, f"{sample_id}.flac") + + return { + "audio_path": audio_path, + "gt_text": example["text"], + "sample_id": sample_id + } + +# --- T2S Evaluation Dataset --- +class T2SEvalDataset(Dataset): + def __init__(self, hf_dataset): + self.hf_dataset = hf_dataset + def __len__(self): + return len(self.hf_dataset) + def __getitem__(self, idx): + example = self.hf_dataset[idx] + return {"gt_text": example['text'], "sample_id": example['id']} + +def _resolve_mask_schedule(config): + schedule_cfg = getattr(config, "mask_schedule", None) + if isinstance(schedule_cfg, DictConfig): + schedule_name = getattr(schedule_cfg, "schedule", None) + params_cfg = getattr(schedule_cfg, "params", None) + elif isinstance(schedule_cfg, dict): + schedule_name = schedule_cfg.get("schedule") + params_cfg = schedule_cfg.get("params") + else: + schedule_name = None + params_cfg = None + if schedule_name is None: + schedule_name = config.training.get("mask_schedule", "cosine") + params = {} + if params_cfg is not None: + if isinstance(params_cfg, DictConfig): + params = OmegaConf.to_container(params_cfg, resolve=True) or {} + elif isinstance(params_cfg, dict): + params = dict(params_cfg) + else: + params = params_cfg + if not isinstance(params, dict): + params = {} + return get_mask_schedule(schedule_name, **params) +def _tensor_to_pil(image_tensor: torch.Tensor) -> Image.Image: + image = torch.clamp((image_tensor.detach().cpu().float() + 1.0) / 2.0, min=0.0, max=1.0) + array = (image.permute(1, 2, 0).numpy() * 255.0).astype(np.uint8) + return Image.fromarray(array) + +################################################################################################ +# +++++++++++++++++++++++++++++++++++++ T2I EVALUATION LOGIC +++++++++++++++++++++++++++++++++++++ +################################################################################################ + +@torch.no_grad() +def evaluate_t2i(model, vq_model_image, uni_prompting, config, accelerator, global_step): + if not accelerator.is_main_process: + return + logger.info("***** Running T2I Evaluation *****") + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.eval() + prompts_file = "/home/work/AIDAS/MMaDA/validation_prompts/quantative.txt" + if not prompts_file: + logger.warning("No validation prompts file configured. Skipping T2I evaluation.") + return + prompts_path = Path(prompts_file) + if not prompts_path.is_absolute(): + prompts_path = Path.cwd() / prompts_path + if not prompts_path.exists(): + repo_root = Path(__file__).resolve().parents[2] + alt_path = repo_root / prompts_file + if alt_path.exists(): + prompts_path = alt_path + try: + with open(prompts_path, "r", encoding="utf-8") as handle: + prompts = [line.strip() for line in handle if line.strip()] + except OSError as exc: + logger.warning(f"Failed to read validation prompts from '{prompts_path}': {exc}. Skipping T2I evaluation.") + return + if not prompts: + logger.warning("Validation prompts file is empty. Skipping T2I evaluation.") + return + max_samples = getattr(config.experiment, "eval_num_t2i_samples", 8) + if not isinstance(max_samples, int) or max_samples <= 0: + max_samples = 8 + prompts = prompts[:max_samples] + mask_schedule = _resolve_mask_schedule(config) + mask_token_id = unwrapped_model.config.mask_token_id + seq_len = getattr(getattr(config.model, "omada", None), "num_vq_tokens", None) + if seq_len is None: + seq_len = getattr(unwrapped_model.config, "num_vq_tokens", None) + if seq_len is None: + logger.warning("Unable to determine image token sequence length. Skipping T2I evaluation.") + return + seq_len = int(seq_len) + device = accelerator.device + image_tokens = torch.full((len(prompts), seq_len), mask_token_id, dtype=torch.long, device=device) + input_ids, attention_mask = uni_prompting((prompts, image_tokens), 't2i_gen') + if config.training.guidance_scale > 0: + uncond_input_ids, uncond_attention_mask = uni_prompting(([''] * len(prompts), image_tokens), 't2i_gen') + cfg_scale = config.training.guidance_scale + else: + uncond_input_ids, uncond_attention_mask = None, None + cfg_scale = 0.0 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + else: + weight_dtype = torch.float32 + use_autocast = accelerator.device.type == "cuda" and accelerator.mixed_precision != "no" + autocast_ctx = torch.autocast("cuda", dtype=weight_dtype) if use_autocast else contextlib.nullcontext() + with autocast_ctx: + gen_token_ids = unwrapped_model.t2i_generate( + input_ids=input_ids, + uncond_input_ids=uncond_input_ids, + attention_mask=attention_mask, + uncond_attention_mask=uncond_attention_mask, + guidance_scale=3.5, + temperature=config.training.get("generation_temperature", 1.0), + timesteps=15, + noise_schedule=mask_schedule, + noise_type=config.training.get("noise_type", "mask"), + predict_all_tokens=config.training.get("predict_all_tokens", False), + seq_len=seq_len, + uni_prompting=uni_prompting, + config=config, + ) + gen_token_ids = torch.clamp(gen_token_ids, min=0, max=unwrapped_model.config.codebook_size - 1) + images = vq_model_image.decode_code(gen_token_ids) + images = torch.clamp((images + 1.0) / 2.0, min=0.0, max=1.0) + images = images.permute(0, 2, 3, 1).cpu().numpy() * 255.0 + pil_images = [Image.fromarray(img.astype(np.uint8)) for img in images] + wandb_images = [wandb.Image(img, caption=prompt) for img, prompt in zip(pil_images, prompts)] + accelerator.log({"eval/t2i_samples": wandb_images}, step=global_step) + +################################################################################################ +# +++++++++++++++++++++++++++++++++++++ I2I EVALUATION LOGIC +++++++++++++++++++++++++++++++++++++ +################################################################################################ + +@torch.no_grad() +def evaluate_i2i(model, vq_model_image, uni_prompting, config, accelerator, global_step): + if not accelerator.is_main_process: + return + logger.info("***** Running I2I Evaluation *****") + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.eval() + dataset_cfg_raw = getattr(config.dataset, "params", {}) + resolution = 512 + + def _cfg_to_dict(cfg): + if cfg is None: + return None + if isinstance(cfg, dict): + return cfg + if isinstance(cfg, DictConfig): + return OmegaConf.to_container(cfg, resolve=True) + return cfg + + dataset_cfg = _cfg_to_dict(dataset_cfg_raw) or {} + + eval_datasets: list[Dataset] = [] + eval_source_names: list[str] = [] + + # HQ-Edit evaluation dataset (always attempt; mirrors training) + try: + hqedit_split = dataset_cfg.get("hqedit_split", "train") + hqedit_eval = HQEditX2IDataset(split=hqedit_split, resolution=resolution) + if len(hqedit_eval) > 0: + eval_datasets.append(hqedit_eval) + eval_source_names.append(f"HQ-Edit[{hqedit_split}]") + else: + logger.warning("HQ-Edit evaluation split '%s' is empty; skipping.", hqedit_split) + except Exception as exc: + logger.warning("Failed to build HQ-Edit evaluation dataset: %s", exc) + + # OpenImage evaluation dataset if configured + openimage_cfg = _cfg_to_dict(dataset_cfg.get("openimage_i2i")) + if openimage_cfg: + try: + openimage_eval = OpenImageI2IDataset( + resolution=resolution, + image_root=openimage_cfg.get("image_root"), + sft_jsonl=openimage_cfg.get("sft_jsonl"), + pref_jsonl=openimage_cfg.get("pref_jsonl"), + multi_turn_jsonl=openimage_cfg.get("multi_turn_jsonl"), + prefer_summarized_text=bool(openimage_cfg.get("prefer_summarized_text", True)), + pref_positive_only=bool(openimage_cfg.get("pref_positive_only", True)), + skip_missing=bool(openimage_cfg.get("skip_missing", True)), + max_samples_per_source=openimage_cfg.get("max_samples_per_source"), + max_total_samples=openimage_cfg.get("max_total_samples"), + seed=openimage_cfg.get("seed"), + ) + if len(openimage_eval) > 0: + eval_datasets.append(openimage_eval) + eval_source_names.append("OpenImage I2I") + else: + logger.warning("OpenImage I2I evaluation dataset is empty; skipping.") + except Exception as exc: + logger.warning("Failed to build OpenImage I2I eval dataset: %s", exc) + + if not eval_datasets: + logger.warning("No i2i evaluation dataset available. Skipping.") + return + + eval_dataset = ( + eval_datasets[0] if len(eval_datasets) == 1 else CombinedX2IDataset(eval_datasets) + ) + logger.info("Using I2I evaluation datasets: %s", ", ".join(eval_source_names)) + + max_samples = getattr(config.experiment, "eval_num_i2i_samples", 8) + + if not isinstance(max_samples, int) or max_samples <= 0: + max_samples = 8 + num_samples = min(max_samples, len(eval_dataset)) + if len(eval_dataset) <= num_samples: + sample_indices = list(range(len(eval_dataset))) + else: + sample_indices = random.sample(range(len(eval_dataset)), num_samples) + samples = [eval_dataset[i] for i in sample_indices] + prompts = [] + original_tensors = [] + target_tensors = [] + for sample in samples: + prompts.append(sample.get("edit_prompt") or sample.get("output_prompt") or "") + original_tensors.append(sample["input_image"]) + target_tensors.append(sample["output_image"]) + original_images = torch.stack(original_tensors, dim=0).to(accelerator.device) + original_tokens = vq_model_image.get_code(original_images) + len(uni_prompting.text_tokenizer) + seq_len = original_tokens.shape[-1] + mask_token_id = unwrapped_model.config.mask_token_id + placeholder = torch.full((num_samples, seq_len), mask_token_id, dtype=torch.long, device=accelerator.device) + input_ids, attention_mask = uni_prompting((prompts, original_tokens, placeholder), 'i2i_gen') + if config.training.guidance_scale > 0: + uncond_input_ids, uncond_attention_mask = uni_prompting( + ([''] * num_samples, original_tokens, placeholder), 'i2i_gen' + ) + cfg_scale = config.training.guidance_scale + else: + uncond_input_ids, uncond_attention_mask = None, None + cfg_scale = 0.0 + mask_schedule = _resolve_mask_schedule(config) + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + else: + weight_dtype = torch.float32 + use_autocast = accelerator.device.type == "cuda" and accelerator.mixed_precision != "no" + autocast_ctx = torch.autocast("cuda", dtype=weight_dtype) if use_autocast else contextlib.nullcontext() + with autocast_ctx: + gen_token_ids = unwrapped_model.i2i_generate( + input_ids=input_ids, + uncond_input_ids=uncond_input_ids, + attention_mask=attention_mask, + uncond_attention_mask=uncond_attention_mask, + guidance_scale=3.5, + temperature=config.training.get("generation_temperature", 1.0), + timesteps=15, + noise_schedule=mask_schedule, + noise_type=config.training.get("noise_type", "mask"), + seq_len=seq_len, + uni_prompting=uni_prompting, + config=config, + ) + gen_token_ids = torch.clamp(gen_token_ids, min=0, max=unwrapped_model.config.codebook_size - 1) + generated_images = vq_model_image.decode_code(gen_token_ids) + generated_images = torch.clamp((generated_images + 1.0) / 2.0, min=0.0, max=1.0) + gen_images_pil = [Image.fromarray((img.permute(1, 2, 0).cpu().numpy() * 255.0).astype(np.uint8)) for img in generated_images] + source_pil = [_tensor_to_pil(tensor) for tensor in original_tensors] + target_pil = [_tensor_to_pil(tensor) for tensor in target_tensors] + log_resolution = getattr(config.experiment, "eval_image_log_resolution", 512) + wandb_images = [] + for prompt, src, pred, tgt in zip(prompts, source_pil, gen_images_pil, target_pil): + composite = Image.new('RGB', (log_resolution * 3, log_resolution)) + src_resized = src.resize((log_resolution, log_resolution), Image.Resampling.LANCZOS) + pred_resized = pred.resize((log_resolution, log_resolution), Image.Resampling.LANCZOS) + tgt_resized = tgt.resize((log_resolution, log_resolution), Image.Resampling.LANCZOS) + composite.paste(src_resized, (0, 0)) + composite.paste(pred_resized, (log_resolution, 0)) + composite.paste(tgt_resized, (log_resolution * 2, 0)) + wandb_images.append(wandb.Image(composite, caption=f"Prompt: {prompt}")) + accelerator.log({"eval/i2i_samples": wandb_images}, step=global_step) + + +################################################################################################ +# +++++++++++++++++++++++++++++++++++++ S2S EVALUATION LOGIC +++++++++++++++++++++++++++++++++++++ +################################################################################################ +@torch.no_grad() +def evaluate_s2s(model, vq_model_audio, uni_prompting, config, accelerator, global_step): + if not accelerator.is_main_process: + return + + logger.info("***** Running S2S Evaluation *****") + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.eval() + + dataset_cfg = getattr(config.dataset, "params", {}) + if isinstance(dataset_cfg, DictConfig): + dataset_cfg = dataset_cfg + s2s_eval_dir = getattr(dataset_cfg, "s2s_eval_dir", "MMaDA/validation_prompts/s2s") + + base_path = Path(s2s_eval_dir) + if not base_path.is_absolute(): + base_path = Path.cwd() / base_path + if not base_path.exists(): + repo_root = Path(__file__).resolve().parents[2] + alt_path = repo_root / s2s_eval_dir + if alt_path.exists(): + base_path = alt_path + + if not base_path.exists(): + logger.warning(f"S2S evaluation directory '{s2s_eval_dir}' not found. Skipping S2S evaluation.") + return + + audio_exts = {".wav", ".flac", ".mp3", ".ogg", ".m4a"} + wav_files = sorted(p for p in base_path.iterdir() if p.is_file() and p.suffix.lower() in audio_exts) + if not wav_files: + logger.warning(f"No audio files found in '{base_path}'. Skipping S2S evaluation.") + return + + condition = getattr(dataset_cfg, "s2s_eval_condition", "gender-female_emotion-neutral_speed-normal_pitch-normal") + mask_token_id = unwrapped_model.config.mask_token_id + codebook_size = int(getattr(config.model.omada, "codebook_size", 8192)) + + speech_vocab_start = len(uni_prompting.text_tokenizer) + codebook_size + audio_codebook_size = 4096 + if audio_codebook_size <= 0: + logger.warning("Computed audio codebook size is non-positive. Skipping S2S evaluation.") + return + + a_tokens = uni_prompting.text_tokenizer("<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n", return_tensors="pt").input_ids + soa_len = uni_prompting.sptids_dict['<|soa|>'].numel() + eoa_len = uni_prompting.sptids_dict['<|eoa|>'].numel() + asst_header_len = a_tokens.shape[1] + max_audio_len = getattr(uni_prompting, "max_audio_len", 256) + max_generatable = 256 + + offset = len(uni_prompting.text_tokenizer) + codebook_size + device = accelerator.device + + output_root = Path(config.experiment.output_dir) / "eval_s2s" / f"step_{global_step}" + output_root.mkdir(parents=True, exist_ok=True) + + table = wandb.Table(columns=["Audio ID", "Source Audio", "Generated Audio", "Token Count"]) + + for audio_path in wav_files: + try: + user_tokens = vq_model_audio.encode(str(audio_path)) + except Exception as exc: + logger.error(f"Failed to encode '{audio_path}': {exc}") + continue + + if not isinstance(user_tokens, torch.Tensor): + user_tokens = torch.tensor(user_tokens) + if user_tokens.dim() == 1: + user_tokens = user_tokens.unsqueeze(0) + + user_tokens = user_tokens.to(device=device, dtype=torch.long) + if user_tokens.numel() == 0: + logger.warning(f"Encoded audio from '{audio_path}' produced no tokens. Skipping sample.") + continue + + # Use a fixed assistant placeholder length so generation is not limited by user input duration. + assistant_len = max_generatable + if assistant_len <= 0: + logger.warning(f"Assistant placeholder length for '{audio_path}' is non-positive. Skipping sample.") + continue + + user_shifted = user_tokens + offset + assistant_placeholder = torch.full( + (1, assistant_len), + mask_token_id, + dtype=torch.long, + device=device, + ) + + input_ids, attention_mask = uni_prompting( + ([user_shifted], [assistant_placeholder]), + 's2s_gen' + ) + + try: + generated_sequences = unwrapped_model.t2s_generate_mmu_like( + input_ids=input_ids, + max_new_tokens=256, + steps=256, + block_length=256, + temperature=config.training.get("s2s_generation_temperature", 1.0), + cfg_scale=config.training.get("s2s_guidance_scale", 3.0), + mask_token_id=mask_token_id, + attention_mask=attention_mask, + uni_prompting=uni_prompting, + codebook_size=codebook_size, + audio_codebook_size=audio_codebook_size, + ) + except Exception as exc: + logger.error(f"Generation failed for '{audio_path}': {exc}") + continue + + if not generated_sequences: + logger.warning(f"No tokens generated for '{audio_path}'. Skipping sample.") + continue + + gen_tokens = generated_sequences[0] + if isinstance(gen_tokens, torch.Tensor): + gen_tokens = gen_tokens.detach().cpu() + token_list = gen_tokens.tolist() + if not token_list: + logger.warning(f"Generated token list empty for '{audio_path}'. Skipping sample.") + continue + + speech_unit_str = "".join([f"<|speech_{int(token)}|>" for token in token_list]) + output_path = output_root / f"{audio_path.stem}_reply.wav" + + try: + vq_model_audio.decode(speech_unit_str, condition=condition, output_wav_file=str(output_path)) + except Exception as exc: + logger.error(f"Decoding failed for '{audio_path}': {exc}") + continue + + table.add_data( + audio_path.name, + wandb.Audio(str(audio_path), caption="source"), + wandb.Audio(str(output_path), caption="generated"), + len(token_list), + ) + + row_count = getattr(table, "num_rows", None) + if row_count is None: + table_data = getattr(table, "data", None) + row_count = len(table_data) if table_data is not None else 0 + + if row_count > 0: + accelerator.log({"eval/s2s_samples": table}, step=global_step) + else: + logger.warning("S2S evaluation produced no loggable samples.") + + +################################################################################################ +# +++++++++++++++++++++++++++++++++++++ TEXT EVALUATION LOGIC ++++++++++++++++++++++++++++++++++++ +################################################################################################ +@torch.no_grad() +def evaluate_text(model, uni_prompting, config, accelerator, global_step): + if not accelerator.is_main_process: + return + + logger.info("***** Running Text Evaluation *****") + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.eval() + + dataset_cfg = getattr(config.dataset, "params", {}) + prompts_file = getattr(dataset_cfg, "text_eval_prompts_file", "MMaDA/validation_prompts/math.txt") + + prompts_path = Path(prompts_file) + if not prompts_path.is_absolute(): + prompts_path = Path.cwd() / prompts_path + if not prompts_path.exists(): + repo_root = Path(__file__).resolve().parents[2] + alt_path = repo_root / prompts_file + if alt_path.exists(): + prompts_path = alt_path + + if not prompts_path.exists(): + logger.warning(f"Text evaluation prompts file '{prompts_file}' not found. Skipping text evaluation.") + return + + try: + with open(prompts_path, "r", encoding="utf-8") as handle: + raw_prompts = [line.strip() for line in handle if line.strip()] + except OSError as exc: + logger.warning(f"Failed to read text evaluation prompts from '{prompts_path}': {exc}. Skipping text evaluation.") + return + + if not raw_prompts: + logger.warning("Text evaluation prompt list is empty. Skipping text evaluation.") + return + + max_samples = getattr(config.experiment, "eval_num_text_samples", 4) + if not isinstance(max_samples, int) or max_samples <= 0: + max_samples = 4 + questions = raw_prompts[:max_samples] + + chat_prompts = [ + f"<|start_header_id|>user<|end_header_id|>\n{question}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n" + for question in questions + ] + + tokenizer = uni_prompting.text_tokenizer + tokenizer.padding_side = "left" + if tokenizer.pad_token_id is None: + tokenizer.pad_token_id = tokenizer.eos_token_id + + answers: list[str] = [] + + for chat_prompt in chat_prompts: + tokens = tokenizer( + chat_prompt, + return_tensors="pt", + padding=True, + truncation=True, + ) + + input_ids = tokens["input_ids"].to(accelerator.device) + out = generate(unwrapped_model, input_ids, steps=128, gen_length=128, block_length=128, temperature=1, cfg_scale=0., remasking='low_confidence') + answer = tokenizer.batch_decode(out[:, input_ids.shape[1]:], skip_special_tokens=True) + + answers.append(answer) + + table = wandb.Table(columns=["Index", "Question", "Answer"]) + for idx, (question, answer) in enumerate(zip(questions, answers)): + table.add_data(idx, question, answer) + + accelerator.log({"eval/text_samples": table}, step=global_step) + +################################################################################################ +# +++++++++++++++++++++++++++++++++++++ S2T EVALUATION LOGIC +++++++++++++++++++++++++++++++++++++ +################################################################################################ +@torch.no_grad() +def evaluate_s2t(model, vq_model_audio, uni_prompting, config, accelerator, global_step): + if not accelerator.is_main_process: + return + logger.info("***** Running S2T Evaluation (WER on Librispeech test-clean) *****") + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.eval() + + # 1. Load Dataset + try: + s2t_eval_dataset_raw = load_dataset("librispeech_asr", "clean", split="test", streaming=False).select(range(32)) + s2t_eval_dataset = S2TEvalDataset(s2t_eval_dataset_raw, root_path = "/home/work/AIDAS/data/audio/LibriSpeech/test-clean") + except Exception as e: + logger.error(f"Failed to load S2T evaluation dataset: {e}") + return + + collate_with_args = partial( + s2t_eval_collate_fn, + vq_model_audio=vq_model_audio, + tokenizer=uni_prompting.text_tokenizer, + uni_prompting=uni_prompting, + config=config + ) + + s2t_eval_dataloader = DataLoader(s2t_eval_dataset, batch_size=config.training.batch_size_s2t, shuffle=False, collate_fn=collate_with_args) + + local_results = [] + + for batch in tqdm(s2t_eval_dataloader, desc="S2T Evaluation"): + input_ids = batch["input_ids"] + gt_texts = batch["gt_texts"] + sample_ids = batch["sample_ids"] + + output_ids = unwrapped_model.mmu_generate(input_ids, max_new_tokens=256, steps=256, block_length=128, remasking='low_confidence') + + decoded_texts = uni_prompting.text_tokenizer.batch_decode(output_ids[:, input_ids.shape[1]:], skip_special_tokens=True) + + eos_token = uni_prompting.text_tokenizer.eos_token + eos_marker = eos_token if eos_token is not None else "" + for i in range(len(decoded_texts)): + full_text = decoded_texts[i] + eos_idx = full_text.find(eos_marker) + cleaned_text = full_text[:eos_idx] if eos_idx != -1 else full_text + cleaned_text = cleaned_text.replace(eos_marker, "").strip() + local_results.append({ + "sample_id": sample_ids[i], + "gt_text": gt_texts[i], + "decoded_text": cleaned_text, + }) + + if not local_results: + logger.warning("S2T evaluation produced no results.") + return + + gt_list = [res["gt_text"] for res in local_results] + pred_list = [res["decoded_text"] for res in local_results] + + wer, errors, words = calculate_wer(pred_list, gt_list) + logger.info(f"S2T Final WER (Librispeech test-clean): {wer:.4f} | Word Errors: {errors} | Total Words: {words}") + + accelerator.log({ + "eval/s2t_wer": wer, + "eval/s2t_word_errors": errors, + "eval/s2t_total_words": words + }, step=global_step) + + samples_table = wandb.Table(columns=["ID", "Ground Truth", "Prediction"]) + for idx, res in enumerate(local_results): + sample_id = res.get("sample_id", idx) + samples_table.add_data(sample_id, res["gt_text"], res["decoded_text"]) + + accelerator.log({"eval/s2t_samples": samples_table}, step=global_step) + +################################################################################################ +# +++++++++++++++++++++++++++++++++++++ T2S EVALUATION LOGIC +++++++++++++++++++++++++++++++++++++ +################################################################################################ +@torch.no_grad() +def evaluate_t2s(model, vq_model_audio, uni_prompting, config, accelerator, global_step): + if not accelerator.is_main_process: + return + logger.info("***** Running T2S Evaluation (WER via Whisper on Librispeech) *****") + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.eval() + + # 1. Load Dataset & Whisper Model + try: + t2s_eval_dataset_raw = load_dataset("librispeech_asr", "clean", split="test").select(range(8)) + whisper_pipe = pipeline("automatic-speech-recognition", model="openai/whisper-large-v3", device=accelerator.device) + os.makedirs(f"{config.experiment.output_dir}/eval_audio", exist_ok=True) + except Exception as e: + logger.error(f"Failed to load T2S dataset or Whisper model: {e}") + return + + output_dir_per_step = os.path.join("/home/work/AIDAS", config.experiment.output_dir, "eval_audio", f"step_{global_step}") + os.makedirs(output_dir_per_step, exist_ok=True) + + t2s_eval_dataset = T2SEvalDataset(t2s_eval_dataset_raw) + t2s_dataloader = DataLoader(t2s_eval_dataset, batch_size=config.training.batch_size_t2s) + + local_results = [] + mask_token_id = unwrapped_model.config.mask_token_id + mask_schedule = get_mask_schedule(config.training.get("mask_schedule", "cosine")) + + # 2. Evaluation Loop + for batch in tqdm(t2s_dataloader, desc="T2S Evaluation"): + gt_texts = batch["gt_text"] + sample_ids = batch["sample_id"] + + # Chat-style instruction formatting for T2S: user prompt + text + prompts = [ + f"<|start_header_id|>user<|end_header_id|>\n{random.choice(T2S_INSTRUCTION)}\n{text}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n" + for text in gt_texts + ] + batch_size = len(prompts) + + # We need a reasonable length for generated audio tokens + speech_token_length = 384 - 1 # -1 for soa token + audio_tokens = torch.ones((batch_size, speech_token_length), dtype=torch.long, device=accelerator.device) * mask_token_id + input_ids, attention_mask = uni_prompting((prompts, audio_tokens), 't2s_gen') + + if config.training.guidance_scale > 0: + uncond_input_ids, uncond_attention_mask = uni_prompting(([''] * batch_size, audio_tokens), 't2s_gen') + else: + uncond_input_ids, uncond_attention_mask = None, None + + output_ids = unwrapped_model.t2s_generate( + input_ids=input_ids, + uncond_input_ids=uncond_input_ids, + attention_mask=attention_mask, + uncond_attention_mask=uncond_attention_mask, + guidance_scale=5.0, + temperature=1.0, + timesteps=50, + noise_schedule=mask_schedule, + noise_type="mask", + seq_len=383, + uni_prompting=uni_prompting, + config=config, + ) + + # Decode and run Whisper + for i in range(batch_size): + gt = gt_texts[i].rsplit("\n", 1)[-1].strip() + + gen_speech_tokens = output_ids[i] + + # Remove padding/eos if necessary, clamp to valid range + # gen_speech_tokens = torch.clamp(gen_speech_tokens, min=0, max= 4096 - 1) + id_list = gen_speech_tokens.cpu().tolist() + + if not id_list: + logger.warning(f"Generated token list is empty for sample {sample_ids[i]}. Skipping.") + continue + + speech_unit_str = " ".join(map(str, id_list)) + speech_unit_for_decode = "".join([f"<|speech_{unit}|>" for unit in speech_unit_str.split(" ")]) + + filename = f"process_{accelerator.process_index}_{sample_ids[i]}.wav" + output_wav_path = os.path.join(output_dir_per_step, filename) + condition = 'gender-female_emotion-neutral_speed-normal_pitch-normal' + + audio_array = vq_model_audio.decode(speech_unit_for_decode, condition=condition, output_wav_file=output_wav_path) + + whisper_result = whisper_pipe(output_wav_path, generate_kwargs={"language": "english"}) + whisper_text = whisper_result.get("text", "") + + local_results.append({ + "sample_id": sample_ids[i], "gt_text": gt, "whisper_text": whisper_text, "audio_path": output_wav_path + }) + + if not local_results: + logger.warning("Skipping T2S evaluation logging because no samples were generated.") + return + + gt_list = [res["gt_text"] for res in local_results] + pred_list = [res["whisper_text"] for res in local_results] + + wer, errors, words = calculate_wer(pred_list, gt_list) + logger.info(f"T2S Final WER (via Whisper): {wer:.4f} | Word Errors: {errors} | Total Words: {words}") + + accelerator.log({ + "eval/t2s_wer": wer, + "eval/t2s_word_errors": errors, + "eval/t2s_total_words": words + }, step=global_step) + + results_table = wandb.Table(columns=["ID", "Ground Truth", "Whisper Transcription", "Generated Audio"]) + for res in local_results[:8]: + audio = wandb.Audio(res["audio_path"], caption=res["whisper_text"]) + results_table.add_data(res["sample_id"], res["gt_text"], res["whisper_text"], audio) + + accelerator.log({"eval/t2s_samples": results_table}, step=global_step) + +@torch.no_grad() +def evaluate_t2s_mmu_like(model, vq_model_audio, uni_prompting, config, accelerator, global_step): + """Text-to-speech evaluation using the MMU-style block refinement decoder.""" + + if not accelerator.is_main_process: + return + + logger.info("***** Running T2S Evaluation (MMU-style decoder) *****") + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.eval() + + try: + t2s_eval_dataset_raw = load_dataset("librispeech_asr", "clean", split="test").select(range(8)) + whisper_pipe = pipeline("automatic-speech-recognition", model="openai/whisper-large-v3", device=accelerator.device) + os.makedirs(f"{config.experiment.output_dir}/eval_audio", exist_ok=True) + except Exception as exc: + logger.error(f"Failed to load T2S dataset or Whisper model for MMU-style eval: {exc}") + return + + output_dir_per_step = os.path.join("/home/work/AIDAS", config.experiment.output_dir, "eval_audio", f"step_{global_step}_mmu") + os.makedirs(output_dir_per_step, exist_ok=True) + + t2s_eval_dataset = T2SEvalDataset(t2s_eval_dataset_raw) + t2s_dataloader = DataLoader(t2s_eval_dataset, batch_size=config.training.batch_size_t2s) + + local_results = [] + mask_token_id = unwrapped_model.config.mask_token_id + + codebook_size = config.model.omada.codebook_size + speech_vocab_size = 4096 + + for batch in tqdm(t2s_dataloader, desc="T2S MMU Eval"): + gt_texts = batch["gt_text"] + sample_ids = batch["sample_id"] + + prompts = [ + f"<|start_header_id|>user<|end_header_id|>\n{random.choice(T2S_INSTRUCTION)}\n{text}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n" + for text in gt_texts + ] + + batch_size = len(prompts) + speech_token_length = 384 - 1 + audio_tokens = torch.ones((batch_size, speech_token_length), dtype=torch.long, device=accelerator.device) * mask_token_id + input_ids, attention_mask = uni_prompting((prompts, audio_tokens), 't2s_gen') + + output_ids = unwrapped_model.t2s_generate_mmu_like( + input_ids=input_ids, + max_new_tokens=speech_token_length, + steps=384 - 1, + block_length=384 - 1, + temperature=1.0, + cfg_scale=3.5, + mask_token_id=mask_token_id, + attention_mask=attention_mask, + uni_prompting=uni_prompting, + codebook_size=codebook_size, + audio_codebook_size=speech_vocab_size, + ) + + for i in range(batch_size): + gt = gt_texts[i].rsplit("\n", 1)[-1].strip() + + gen_speech_tokens = output_ids[i] + if isinstance(gen_speech_tokens, torch.Tensor): + gen_speech_tokens = gen_speech_tokens.detach().cpu() + + token_list = gen_speech_tokens.tolist() + if not token_list: + logger.warning(f"Generated token list is empty for sample {sample_ids[i]} (MMU eval). Skipping.") + continue + + speech_unit_str = " ".join(map(str, token_list)) + speech_unit_for_decode = "".join([f"<|speech_{unit}|>" for unit in speech_unit_str.split(" ")]) + + filename = f"process_{accelerator.process_index}_{sample_ids[i]}_mmu.wav" + output_wav_path = os.path.join(output_dir_per_step, filename) + condition = 'gender-female_emotion-neutral_speed-normal_pitch-normal' + + try: + vq_model_audio.decode(speech_unit_for_decode, condition=condition, output_wav_file=output_wav_path) + except Exception as exc: + logger.error(f"Decoding failed for sample {sample_ids[i]} (MMU eval): {exc}") + continue + + whisper_result = whisper_pipe(output_wav_path, generate_kwargs={"language": "english"}) + whisper_text = whisper_result.get("text", "") + + local_results.append({ + "sample_id": sample_ids[i], + "gt_text": gt, + "whisper_text": whisper_text, + "audio_path": output_wav_path, + }) + + if not local_results: + logger.warning("Skipping T2S MMU-style evaluation because no samples were generated.") + return + + gt_list = [res["gt_text"] for res in local_results] + pred_list = [res["whisper_text"] for res in local_results] + + wer, errors, words = calculate_wer(pred_list, gt_list) + logger.info(f"T2S (MMU-style) Final WER: {wer:.4f} | Word Errors: {errors} | Total Words: {words}") + + accelerator.log({ + "eval/t2s_mmu_like_wer": wer, + "eval/t2s_mmu_like_word_errors": errors, + "eval/t2s_mmu_like_total_words": words, + }, step=global_step) + + results_table = wandb.Table(columns=["ID", "Ground Truth", "Whisper Transcription", "Generated Audio"]) + for res in local_results[:8]: + audio = wandb.Audio(res["audio_path"], caption=res["whisper_text"]) + results_table.add_data(res["sample_id"], res["gt_text"], res["whisper_text"], audio) + + accelerator.log({"eval/t2s_mmu_like_samples": results_table}, step=global_step) + +################################################################################################ +# +++++++++++++++++++++++++++++++++++++ V2T EVALUATION LOGIC +++++++++++++++++++++++++++++++++++++ +################################################################################################ +@torch.no_grad() +def evaluate_v2t(model, vq_model_image, uni_prompting, config, accelerator, global_step): + # This is a qualitative evaluation, so it only runs on the main process. + if not accelerator.is_main_process: + return + + logger.info("***** Running V2T Qualitative Evaluation *****") + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.eval() + + video_root = "/home/work/AIDAS/video/demo" + if not video_root or not os.path.exists(video_root): + logger.warning(f"V2T eval root '{video_root}' not found. Skipping V2T evaluation.") + return + + file_list = [f for f in os.listdir(video_root) if f.lower().endswith('.mp4')] + if not file_list: + logger.warning(f"No .mp4 files found in '{video_root}'. Skipping V2T evaluation.") + return + + question = "Please provide a detailed description of the video." + results_table = wandb.Table(columns=["Video ID", "Question", "Generated Caption"]) + + for file_name in tqdm(file_list[:], desc="V2T Evaluation", disable=not accelerator.is_main_process): + video_path = os.path.join(video_root, file_name) + + # 1. Load and process video + cap = cv2.VideoCapture(video_path) + total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + indices = np.linspace(0, total_frames - 1, 8, dtype=int) + frames = [] + for i in range(total_frames): + ret, frame = cap.read() + if i in indices: + if not ret: continue + frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + pil_img = Image.fromarray(frame) + frames.append(image_transform(pil_img, resolution=config.dataset.preprocessing.resolution)) + cap.release() + + if len(frames) < 8: continue + + video_tensor = torch.stack(frames).to(accelerator.device) + video_tokens = vq_model_image.get_code(video_tensor) + len(uni_prompting.text_tokenizer) + video_tokens = video_tokens.view(1, -1) # Flatten tokens + + sptids = uni_prompting.sptids_dict + device = unwrapped_model.device + + prompt_text = f'<|start_header_id|>user<|end_header_id|>\n{question}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n' + prompt_tensor = uni_prompting.text_tokenizer(prompt_text, return_tensors="pt").input_ids.to(device) + + input_ids = torch.cat([ + sptids['<|v2t|>'].to(device).unsqueeze(0), + sptids['<|soi|>'].to(device).unsqueeze(0), + video_tokens, + sptids['<|eoi|>'].to(device).unsqueeze(0), + sptids['<|sot|>'].to(device).unsqueeze(0), + prompt_tensor + ], dim=1).long() + + output_ids = unwrapped_model.mmu_generate(input_ids, max_new_tokens=256, steps=256, block_length=128) + text = uni_prompting.text_tokenizer.batch_decode(output_ids[:, input_ids.shape[1]:], skip_special_tokens=True)[0] + print(text) + # 3. Log result + results_table.add_data(file_name, question, text) + + # except Exception as e: + # logger.error(f"Error processing video {file_name}: {e}") + + accelerator.log({"eval/v2t_qualitative_samples": results_table}, step=global_step) + + +################################################################################################ +# +++++++++++++++++++++++++++++++++++++ V2S EVALUATION LOGIC +++++++++++++++++++++++++++++++++++++ +################################################################################################ +@torch.no_grad() +def evaluate_v2s(model, vq_model_image, vq_model_audio, uni_prompting, config, accelerator, global_step): + if not accelerator.is_main_process: + return + + logger.info("***** Running V2S Qualitative Evaluation *****") + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.eval() + + try: + whisper_pipe = pipeline("automatic-speech-recognition", model="openai/whisper-large-v3", device=accelerator.device) + except Exception as exc: + logger.error(f"Failed to load Whisper model for V2S eval: {exc}") + return + + video_root = "/home/work/AIDAS/video/demo" + if not video_root or not os.path.exists(video_root): + logger.warning(f"V2S eval root '{video_root}' not found. Skipping V2S evaluation.") + return + + file_list = [f for f in os.listdir(video_root) if f.lower().endswith('.mp4')] + if not file_list: + logger.warning(f"No .mp4 files found in '{video_root}'. Skipping V2S evaluation.") + return + + prompt_candidates = V2S_INSTRUCTION + results_table = wandb.Table(columns=["Video ID", "Question", "Whisper Transcript", "Generated Audio"]) + + device = unwrapped_model.device + mask_token_id = unwrapped_model.config.mask_token_id + eoa_token_id = int(uni_prompting.sptids_dict['<|eoa|>'][0].item()) + audio_codebook_size = 4096 + max_audio_tokens = int(getattr(uni_prompting, "max_audio_len_short", config.dataset.preprocessing.max_aud_length_short)) + max_new_tokens = max(1, max_audio_tokens - 1) + block_length = 128 if max_new_tokens % 128 == 0 else max_new_tokens + + output_dir = os.path.join("/home/work/AIDAS", config.experiment.output_dir, "eval_audio_v2s", f"step_{global_step}") + os.makedirs(output_dir, exist_ok=True) + + for file_name in tqdm(file_list[:], desc="V2S Evaluation", disable=not accelerator.is_main_process): + video_path = os.path.join(video_root, file_name) + + cap = cv2.VideoCapture(video_path) + total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + indices = np.linspace(0, total_frames - 1, 8, dtype=int) + frames = [] + for idx in range(total_frames): + ret, frame = cap.read() + if idx in indices: + if not ret: + continue + frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + pil_img = Image.fromarray(frame) + frames.append(image_transform(pil_img, resolution=config.dataset.preprocessing.resolution)) + cap.release() + + if len(frames) < 8: + logger.warning(f"Skipping {file_name}: insufficient frames.") + continue + + video_tensor = torch.stack(frames).to(device) + try: + video_tokens = vq_model_image.get_code(video_tensor) + len(uni_prompting.text_tokenizer) + except Exception as exc: + logger.error(f"Failed to encode video {file_name}: {exc}") + continue + video_tokens = video_tokens.view(1, -1) + + question = random.choice(prompt_candidates) + prompt_text = f'<|start_header_id|>user<|end_header_id|>\n{question}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n' + prompt_tensor = uni_prompting.text_tokenizer(prompt_text, return_tensors="pt").input_ids.to(device) + + audio_placeholder = torch.full( + (1, max_new_tokens), + mask_token_id, + dtype=torch.long, + device=device, + ) + try: + seq_ids, attn_mask = uni_prompting( + (video_tokens, [prompt_text], [audio_placeholder]), + 'v2s_gen' + ) + except Exception as exc: + logger.error(f"Prompt construction failed for {file_name}: {exc}") + continue + + input_ids = seq_ids.to(device) + attention_mask = attn_mask.to(device) + + try: + generated_list = unwrapped_model.t2s_generate_mmu_like( + input_ids=input_ids, + max_new_tokens=256, + steps=256, + block_length=256, + temperature=1.0, + cfg_scale=3.0, + mask_token_id=mask_token_id, + attention_mask=attention_mask, + uni_prompting=uni_prompting, + codebook_size=config.model.omada.codebook_size, + audio_codebook_size=audio_codebook_size, + ) + except Exception as exc: + logger.error(f"Generation failed for {file_name}: {exc}") + continue + + generated_tokens = generated_list[0] + if isinstance(generated_tokens, torch.Tensor): + generated_tokens = generated_tokens.detach().cpu() + + collected = [] + for token in generated_tokens.tolist(): + if token == eoa_token_id or token >= audio_codebook_size: + break + if token >= 0: + collected.append(token) + + if not collected: + logger.warning(f"No valid audio tokens generated for {file_name}.") + continue + + speech_unit_for_decode = "".join(f"<|speech_{tok}|>" for tok in collected) + output_wav_path = os.path.join(output_dir, f"{Path(file_name).stem}_v2s.wav") + try: + vq_model_audio.decode( + speech_unit_for_decode, + condition='gender-female_emotion-neutral_speed-normal_pitch-normal', + output_wav_file=output_wav_path + ) + except Exception as exc: + logger.error(f"Decoding failed for {file_name}: {exc}") + continue + + whisper_result = whisper_pipe(output_wav_path, generate_kwargs={"language": "english"}) + whisper_text = whisper_result.get("text", "") + + results_table.add_data( + file_name, + question, + whisper_text, + wandb.Audio(output_wav_path, caption=whisper_text) + ) + + if len(results_table.data) == 0: + logger.warning("V2S evaluation produced no samples to log.") + return + + accelerator.log({"eval/v2s_qualitative_samples": results_table}, step=global_step) + + +################################################################################################ +# +++++++++++++++++++++++++++++++++++++ MAIN EVALUATION ORCHESTRATOR +++++++++++++++++++++++++++++ +################################################################################################ + +def run_evaluation(model, vq_model_image, vq_model_audio, uni_prompting, config, accelerator, global_step): + """ + Orchestrates the S2T, T2S, and V2T evaluations. + """ + if accelerator.is_main_process: + logger.info(f"--- Starting evaluation at step {global_step} ---") + model.eval() + + + if accelerator.is_main_process: + evaluate_v2s(model, vq_model_image, vq_model_audio, uni_prompting, config, accelerator, global_step) + evaluate_text(model, uni_prompting, config, accelerator, global_step) + evaluate_t2i(model, vq_model_image, uni_prompting, config, accelerator, global_step) + evaluate_i2i(model, vq_model_image, uni_prompting, config, accelerator, global_step) + evaluate_s2s(model, vq_model_audio, uni_prompting, config, accelerator, global_step) + evaluate_s2t(model, vq_model_audio, uni_prompting, config, accelerator, global_step) + evaluate_t2s_mmu_like(model, vq_model_audio, uni_prompting, config, accelerator, global_step) + # evaluate_v2t(model, vq_model_image, uni_prompting, config, accelerator, global_step) + + accelerator.wait_for_everyone() + if accelerator.is_main_process: + logger.info(f"--- Finished evaluation at step {global_step}. Returning to training. ---") + model.train() + + +def main(): + _configure_multiprocessing() + ######################### + # SETUP Accelerator # + ######################### + config = get_config() + + # Enable TF32 on Ampere GPUs + if config.training.enable_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.benchmark = True + torch.backends.cudnn.deterministic = False + + config.experiment.logging_dir = str(Path(config.experiment.output_dir) / "logs") + accelerator = Accelerator( + gradient_accumulation_steps=config.training.gradient_accumulation_steps, + mixed_precision=config.training.mixed_precision, + log_with="wandb", + project_dir=config.experiment.logging_dir, + split_batches=True, + ) + + total_batch_size_per_gpu = ( + config.training.batch_size_t2i + + config.training.batch_size_lm + + config.training.batch_size_mmu + + config.training.batch_size_v2t + + config.training.batch_size_s2t + + config.training.batch_size_t2s + + config.training.batch_size_s2s + ) - 1 # -1 since t2s/ s2t choice + + total_batch_size = ( + total_batch_size_per_gpu + * accelerator.num_processes + * config.training.gradient_accumulation_steps + ) + + if accelerator.distributed_type == DistributedType.DEEPSPEED: + accelerator.state.deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = ( + total_batch_size_per_gpu + ) + + ##################################### + # SETUP LOGGING, SEED and CONFIG # + ##################################### + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + set_verbosity_info() + else: + set_verbosity_error() + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + resume_wandb_run = config.wandb.resume + run_id = config.wandb.get("run_id", None) + if run_id is None: + resume_wandb_run = False + run_id = wandb.util.generate_id() + config.wandb.run_id = run_id + + wandb_init_kwargs = dict( + name=config.experiment.name, + id=run_id, + resume=resume_wandb_run, + entity=config.wandb.get("entity", None), + config_exclude_keys=[], + dir = config.experiment.logging_dir, + ) + wandb_config = {k: v for k, v in flatten_omega_conf(config, resolve=True)} + wandb_config.pop("experiment.resume_from_checkpoint") + + accelerator.init_trackers( + config.experiment.project, + config=wandb_config, + init_kwargs={"wandb": wandb_init_kwargs}, + ) + + if accelerator.is_main_process: + os.makedirs(config.experiment.output_dir, exist_ok=True) + config_path = Path(config.experiment.output_dir) / "config.yaml" + logging.info(f"Saving config to {config_path}") + OmegaConf.save(config, config_path) + + # If passed along, set the training seed now. + if config.training.seed is not None: + set_seed(config.training.seed) + + ######################### + # MODELS and OPTIMIZER # + ######################### + logger.info("Loading models and optimizer") + + tokenizer = AutoTokenizer.from_pretrained(config.model.omada.tokenizer_path, padding_side="left") + + uni_prompting = UniversalPrompting(tokenizer, max_text_len=config.dataset.preprocessing.max_seq_length, max_audio_len=config.dataset.preprocessing.max_aud_length, max_audio_len_short=config.dataset.preprocessing.max_aud_length_short, + special_tokens=( + "<|soi|>", "<|eoi|>", "<|sov|>", "<|eov|>", "<|t2i|>", + "<|mmu|>", "<|t2v|>", "<|v2v|>", "<|lvg|>", + # Omada Special Tokens + "<|i2i|>", "<|v2t|>", "<|v2s|>", "<|s2t|>", "<|t2s|>", "<|s2s|>", "<|soa|>", "<|eoa|>", + ), + ignore_id=-100, cond_dropout_prob=config.training.cond_dropout_prob, use_reserved_token=True) + + print('special tokens : \n', uni_prompting.sptids_dict) + + speech_vocab_start = len(uni_prompting.text_tokenizer) + int(config.model.omada.codebook_size) + audio_codebook_size = max(int(config.model.omada.new_vocab_size) - speech_vocab_start, 0) + + # speech_vocab_start = int(config.model.omada.llm_vocab_size) + int(config.model.omada.codebook_size) # 126464 + 8192 = 134656 + # audio_codebook_size = max(int(config.model.omada.new_vocab_size) - speech_vocab_start, 0) # 4096 + + logger.info(f"SPEECHVOCABSTART: {speech_vocab_start}") + logger.info(f"int(config.model.omada.new_vocab_size): {int(config.model.omada.new_vocab_size)}") + logger.info(f"AUDIOCODEBOOKSIZE: {audio_codebook_size}") + + t2s_special_token_ids = { + "eoa": int(uni_prompting.sptids_dict['<|eoa|>'][0].item()), + "eos": int(uni_prompting.text_tokenizer.eos_token_id), + } + + # VQ model for processing image into discrete tokens + vq_model_image = get_vq_model_class(config.model.vq_model_image.type) + if config.model.vq_model_image.get("pretrained_model_path", None): + vq_model_image = vq_model_image().to(accelerator.device) + state_dict = torch.load(config.model.vq_model_image.pretrained_model_path)['model'] + vq_model_image.load_state_dict(state_dict) + else: + vq_model_image = vq_model_image.from_pretrained(config.model.vq_model_image.vq_model_name).to(accelerator.device) + + vq_model_audio = get_vq_model_class(config.model.vq_model_audio.type) + vq_model_audio = vq_model_audio.from_pretrained(config.model.vq_model_audio.vq_model_name).to(accelerator.device) + + vq_model_image.eval() + vq_model_image.requires_grad_(False) + + vq_model_audio.eval() + vq_model_audio.requires_grad_(False) + + # Speech-token caching configuration + speech_cache_cfg = getattr(config.dataset, "speech_token_cache", {}) + if not isinstance(speech_cache_cfg, dict): + speech_cache_cfg = OmegaConf.to_container(speech_cache_cfg, resolve=True) + speech_cache_cfg = speech_cache_cfg or {} + + speech_cache_enabled = bool(speech_cache_cfg.get("enable", False)) + speech_cache_dir: Optional[Path] + if speech_cache_enabled: + cache_root = speech_cache_cfg.get("root", "cache/speech_tokens") + speech_cache_dir = Path(cache_root) + try: + speech_cache_dir.mkdir(parents=True, exist_ok=True) + except OSError: + speech_cache_dir = None + speech_cache_enabled = False + logger.warning("Failed to create speech cache directory at %s; disabling cache.", cache_root) + else: + speech_cache_dir = None + + speech_cache_max_items = int(speech_cache_cfg.get("max_items_in_memory", 4096)) + audio_token_cache_mem: Dict[str, torch.Tensor] = {} + + def _get_audio_cache_path(audio_path: Union[str, Path]) -> Optional[Path]: + if not isinstance(audio_path, (str, os.PathLike)): + return None + if not speech_cache_enabled or speech_cache_dir is None: + return None + key = os.path.abspath(str(audio_path)) + digest = hashlib.sha1(key.encode("utf-8")).hexdigest() + subdir = speech_cache_dir / digest[:2] / digest[2:4] + return subdir / f"{digest}.pt" + + def _load_cached_audio_tokens(audio_path: Union[str, Path]) -> Optional[torch.Tensor]: + if not isinstance(audio_path, (str, os.PathLike)): + return None + cache_key = os.path.abspath(str(audio_path)) + cached = audio_token_cache_mem.get(cache_key) + if cached is not None: + return cached.clone() + + cache_path = _get_audio_cache_path(audio_path) + if cache_path is None or not cache_path.exists(): + return None + try: + tokens = torch.load(cache_path, map_location="cpu") + if isinstance(tokens, torch.Tensor): + if len(audio_token_cache_mem) < speech_cache_max_items: + audio_token_cache_mem[cache_key] = tokens + return tokens.clone() + except Exception as exc: + logger.warning("Failed to load cached speech tokens from %s (%s); ignoring cache.", cache_path, exc) + return None + + def _store_cached_audio_tokens(audio_path: Union[str, Path], tokens: torch.Tensor) -> None: + if not isinstance(audio_path, (str, os.PathLike)): + return + cache_path = _get_audio_cache_path(audio_path) + if cache_path is None: + return + try: + cache_path.parent.mkdir(parents=True, exist_ok=True) + tmp_path = cache_path.with_suffix(cache_path.suffix + ".tmp") + torch.save(tokens.cpu(), tmp_path) + os.replace(tmp_path, cache_path) + except Exception as exc: + logger.warning("Failed to write speech token cache to %s (%s).", cache_path, exc) + return + cache_key = os.path.abspath(str(audio_path)) + if len(audio_token_cache_mem) < speech_cache_max_items: + audio_token_cache_mem[cache_key] = tokens.cpu() + + model = OMadaModelLM.from_pretrained(config.model.omada.pretrained_model_path, torch_dtype=torch.bfloat16, config='/home/work/AIDAS/ckpts/omada/omada-training-stage1_7th/checkpoint-315000/unwrapped_model/config.json').to(accelerator.device) + mask_id = model.config.mask_token_id + + ################################## + # Optimizer and LR scheduler # + ################################# + optimizer_config = config.optimizer.params + + # no decay on bias and layernorm and embedding + no_decay = ["bias", "layer_norm.weight", "mlm_ln.weight", "embeddings.weight"] + optimizer_grouped_parameters = [ + { + "params": [p for n, p in model.named_parameters() if + p.requires_grad and not any(nd in n for nd in no_decay)], + "weight_decay": optimizer_config.weight_decay, + }, + { + "params": [p for n, p in model.named_parameters() if + p.requires_grad and any(nd in n for nd in no_decay)], + "weight_decay": 0.0, + }, + ] + + optimizer_type = config.optimizer.name + if optimizer_type == "adamw": + optimizer = AdamW( + optimizer_grouped_parameters, + lr=optimizer_config.learning_rate, + betas=(optimizer_config.beta1, optimizer_config.beta2), + weight_decay=optimizer_config.weight_decay, + eps=optimizer_config.epsilon, + ) + else: + raise ValueError(f"Optimizer {optimizer_type} not supported") + + # Create mask scheduler + if config.get("mask_schedule", None) is not None: + schedule = config.mask_schedule.schedule + args = config.mask_schedule.get("params", {}) + mask_schedule = get_mask_schedule(schedule, **args) + else: + mask_schedule = get_mask_schedule(config.training.get("mask_schedule", "cosine")) + + ################################## + # DATALOADER # + ################################# + logger.info("Creating dataloaders and lr_scheduler") + + def build_distributed_sampler(dataset, *, shuffle=True, drop_last=True): + """Create a DistributedSampler only when running with multiple processes.""" + if dataset is None or accelerator.num_processes <= 1: + return None + return DistributedSampler( + dataset, + num_replicas=accelerator.num_processes, + rank=accelerator.process_index, + shuffle=shuffle, + drop_last=drop_last, + ) + + batch_size_t2i_cfg = config.training.batch_size_t2i + batch_size_lm_cfg = config.training.batch_size_lm + batch_size_mmu_cfg = config.training.batch_size_mmu + batch_size_t2s_cfg = config.training.batch_size_t2s + batch_size_s2t_cfg = config.training.batch_size_s2t + batch_size_v2t_cfg = config.training.batch_size_v2t + batch_size_s2s_cfg = config.training.batch_size_s2s + batch_size_v2s_cfg = batch_size_v2t_cfg + + total_batch_size = ( + total_batch_size_per_gpu + * accelerator.num_processes + * config.training.gradient_accumulation_steps + ) + preproc_config = config.dataset.preprocessing + dataset_config = config.dataset.params + + pin_memory = bool(getattr(dataset_config, "pin_memory", False)) + persistent_workers = bool(getattr(dataset_config, "persistent_workers", False)) + dataloader_timeout = int(getattr(dataset_config, "dataloader_timeout", 120)) + + if persistent_workers and dataloader_timeout > 0: + logger.warning( + "persistent_workers=True requires dataloader_timeout=0; overriding timeout=%s", + dataloader_timeout, + ) + dataloader_timeout = 0 + + if ( + not persistent_workers + and int(getattr(dataset_config, "num_workers", 0)) > 0 + and str(config.dataset.combined_loader_mode) == "max_size_cycle" + ): + logger.warning( + "Using combined_loader_mode='max_size_cycle' with num_workers>0 and " + "persistent_workers=False can exhaust OS semaphores when loaders cycle. " + "Set dataset.params.persistent_workers=True to keep worker processes alive." + ) + + # Text-to-image / Image-to-image datasets + logger.info("Loading Text-to-image / Image-to-image datasets") + dataset_t2i = None + dataset_i2i = None + train_dataloader_t2i = None + train_dataloader_i2i = None + sampler_t2i: Optional[DistributedSampler] = None # type: ignore[assignment] + sampler_i2i: Optional[DistributedSampler] = None # type: ignore[assignment] + if batch_size_t2i_cfg > 0: + raw_t2i_choice = dataset_config.get("t2i_dataset", "hqedit") + if isinstance(raw_t2i_choice, str): + split_tokens = [token.strip() for token in raw_t2i_choice.replace(",", "+").split("+")] + dataset_choices = [token for token in split_tokens if token] + else: + dataset_choices = [str(token).strip() for token in raw_t2i_choice if str(token).strip()] + + if not dataset_choices: + raise ValueError("t2i_dataset configuration produced no valid dataset names.") + + t2i_datasets: list[Dataset] = [] + i2i_datasets: list[Dataset] = [] + t2i_source_names: list[str] = [] + i2i_source_names: list[str] = [] + for choice in dataset_choices: + choice_lower = choice.lower() + if choice_lower in {"hqedit", "hq-edit", "hq_edit"}: + i2i_datasets.append( + HQEditX2IDataset( + split=dataset_config.get("hqedit_split", "train"), + resolution=dataset_config.resolution, + ) + ) + logger.info("Using HQ-Edit dataset for T2I/i2i branch (%s split)", dataset_config.get("hqedit_split", "train")) + i2i_source_names.append(choice) + elif choice_lower in {"text2image2m", "text-to-image-2m", "text_to_image_2m"}: + t2i_datasets.append( + TextToImage2MDataset( + split=dataset_config.get("t2i_split", "train"), + resolution=dataset_config.resolution, + dataset_name=dataset_config.get("t2i_dataset_name", "jackyhate/text-to-image-2M"), + cache_dir=dataset_config.get("t2i_cache_dir", None), + local_files_only=bool(dataset_config.get("t2i_local_files_only", False)), + ) + ) + logger.info( + "Using text-to-image-2M dataset for T2I branch (split=%s, dataset=%s)", + dataset_config.get("t2i_split", "train"), + dataset_config.get("t2i_dataset_name", "jackyhate/text-to-image-2M"), + ) + t2i_source_names.append(choice) + elif choice_lower in {"openimage", "openimage_i2i", "openimage-edit", "openimage_local"}: + raw_openimage_cfg = getattr(dataset_config, "openimage_i2i", None) + if raw_openimage_cfg is None: + raise ValueError("dataset.params.openimage_i2i must be configured to use the OpenImage dataset.") + if not isinstance(raw_openimage_cfg, dict): + openimage_cfg = OmegaConf.to_container(raw_openimage_cfg, resolve=True) or {} + else: + openimage_cfg = raw_openimage_cfg + + i2i_datasets.append( + OpenImageI2IDataset( + resolution=dataset_config.resolution, + image_root=openimage_cfg.get("image_root"), + sft_jsonl=openimage_cfg.get("sft_jsonl"), + pref_jsonl=openimage_cfg.get("pref_jsonl"), + multi_turn_jsonl=openimage_cfg.get("multi_turn_jsonl"), + prefer_summarized_text=bool(openimage_cfg.get("prefer_summarized_text", True)), + pref_positive_only=bool(openimage_cfg.get("pref_positive_only", True)), + skip_missing=bool(openimage_cfg.get("skip_missing", True)), + max_samples_per_source=openimage_cfg.get("max_samples_per_source"), + max_total_samples=openimage_cfg.get("max_total_samples"), + seed=openimage_cfg.get("seed"), + ) + ) + cfg_paths = [ + openimage_cfg.get("sft_jsonl"), + openimage_cfg.get("pref_jsonl"), + openimage_cfg.get("multi_turn_jsonl"), + ] + cfg_paths = [str(path) for path in cfg_paths if path] + logger.info( + "Using OpenImage local edit dataset for i2i (jsonl=%s)", + ", ".join(cfg_paths) if cfg_paths else "n/a", + ) + i2i_source_names.append(choice) + else: + raise ValueError(f"Unsupported t2i_dataset '{choice}'") + + if t2i_datasets: + dataset_t2i = ( + t2i_datasets[0] + if len(t2i_datasets) == 1 + else CombinedX2IDataset(t2i_datasets) + ) + logger.info( + "T2I dataloading sources: %s", + ", ".join(t2i_source_names) if t2i_source_names else "n/a", + ) + + sampler_t2i = build_distributed_sampler( + dataset_t2i, + shuffle=True, + drop_last=True, + ) + + train_dataloader_t2i = DataLoader( + dataset_t2i, + batch_size=batch_size_t2i_cfg, + sampler=sampler_t2i, + shuffle=sampler_t2i is None, + num_workers=dataset_config.num_workers, + collate_fn=collate_fn_x2i, + drop_last=True, + pin_memory=pin_memory, + timeout=dataloader_timeout, + persistent_workers=persistent_workers, + ) + + if i2i_datasets: + dataset_i2i = ( + i2i_datasets[0] + if len(i2i_datasets) == 1 + else CombinedX2IDataset(i2i_datasets) + ) + logger.info( + "I2I dataloading sources: %s", + ", ".join(i2i_source_names) if i2i_source_names else "n/a", + ) + + sampler_i2i = build_distributed_sampler( + dataset_i2i, + shuffle=True, + drop_last=True, + ) + + train_dataloader_i2i = DataLoader( + dataset_i2i, + batch_size=batch_size_t2i_cfg, + sampler=sampler_i2i, + shuffle=sampler_i2i is None, + num_workers=dataset_config.num_workers, + collate_fn=collate_fn_x2i, + drop_last=True, + pin_memory=pin_memory, + timeout=dataloader_timeout, + persistent_workers=persistent_workers, + ) + + # Language modeling dataset (HF instruction mixture) + logger.info("Loading LM dataset") + dataset_lm = None + train_dataloader_lm = None + if batch_size_lm_cfg > 0: + instruction_cfg = getattr(dataset_config, "hf_instruction_lm", {}) + if not isinstance(instruction_cfg, dict): + instruction_cfg = OmegaConf.to_container(instruction_cfg, resolve=True) + instruction_cfg = instruction_cfg or {} + + seed_lm = instruction_cfg.get("seed") + if seed_lm is None: + seed_lm = getattr(config.training, "seed", 42) or 42 + + dataset_lm = HFInstructionTextDataset( + split=instruction_cfg.get("split", "train"), + max_samples_per_source=instruction_cfg.get("max_samples_per_source"), + max_total_samples=instruction_cfg.get("max_total_samples"), + seed=int(seed_lm), + ) + + sampler_lm = build_distributed_sampler( + dataset_lm, + shuffle=True, + drop_last=True, + ) + + train_dataloader_lm = DataLoader( + dataset_lm, + batch_size=batch_size_lm_cfg, + sampler=sampler_lm, + shuffle=sampler_lm is None, + collate_fn=dataset_lm.collate_fn, + num_workers=dataset_config.num_workers, + drop_last=True, + pin_memory=pin_memory, + timeout=dataloader_timeout, + persistent_workers=persistent_workers, + ) + + # Video Dataset + logger.info("Loading Video dataset") + dataset_v2t = None + dataset_v2s = None + train_dataloader_v2t = None + train_dataloader_v2s = None + sampler_v2t = None + sampler_v2s = None + speech_cfg = getattr(dataset_config, "video_speech_dataset", {}) + if not isinstance(speech_cfg, dict): + speech_cfg = OmegaConf.to_container(speech_cfg, resolve=True) + speech_cfg = speech_cfg or {} + + if batch_size_v2t_cfg > 0: + v2t_sample_method = speech_cfg.get( + "v2t_sample_method", + speech_cfg.get("sample_method", "uniform"), + ) + llavavid_max_seconds_cfg = speech_cfg.get("llavavid_max_video_seconds") + if llavavid_max_seconds_cfg is not None: + try: + llavavid_max_seconds_cfg = float(llavavid_max_seconds_cfg) + except (TypeError, ValueError): + llavavid_max_seconds_cfg = None + dataset_v2t = VideoCaptionDataset( + transform=image_transform, + tokenizer=uni_prompting.text_tokenizer, + max_seq_length=preproc_config.max_seq_length, + resolution=preproc_config.resolution, + sample_method=v2t_sample_method, + dataset_name=speech_cfg.get("llavavid_dataset_name", "llavavid"), + llavavid_path=speech_cfg.get("llavavid_path", "lmms-lab/LLaVA-Video-178K"), + num_frames=8, + llavavid_local_files_only=bool(speech_cfg.get("llavavid_local_files_only", False)), + llavavid_skip_configs=speech_cfg.get("llavavid_skip_configs"), + llavavid_skip_video_patterns=speech_cfg.get("llavavid_skip_video_patterns"), + max_video_seconds=llavavid_max_seconds_cfg, + ) + + sampler_v2t = build_distributed_sampler( + dataset_v2t, + shuffle=True, + drop_last=True, + ) + + train_dataloader_v2t = DataLoader( + dataset_v2t, + batch_size=batch_size_v2t_cfg, + num_workers=dataset_config.num_workers, + collate_fn=collate_fn_v2t, + sampler=sampler_v2t, + shuffle=sampler_v2t is None, + drop_last=True, + pin_memory=pin_memory, + timeout=dataloader_timeout, + persistent_workers=persistent_workers, + ) + + if batch_size_v2s_cfg > 0: + video_root_cfg = speech_cfg.get("video_root", "/home/work/AIDAS/data/video/openvid1m/video/video") + audio_root_cfg = speech_cfg.get("audio_root", "/home/work/AIDAS/data/video/openvid1m") + speech_dir_name_cfg = speech_cfg.get("speech_dir_name", "speech_wavs") + index_path_cfg = speech_cfg.get("index_path", "/home/work/AIDAS/data/video/openvid1m/openvid_speech_new.csv") + index_cache_path_cfg = speech_cfg.get("index_cache_path") + max_video_seconds_cfg = speech_cfg.get("max_video_seconds") + validate_paths_cfg = bool(speech_cfg.get("validate_paths", False)) + precomputed_root_cfg = speech_cfg.get("precomputed_tokens_root") + if isinstance(precomputed_root_cfg, str) and precomputed_root_cfg.upper() == "NONE": + precomputed_root_cfg = None + if not speech_cfg.get("use_precomputed_tokens", True): + precomputed_root_cfg = None + sample_method_v2s = speech_cfg.get("sample_method", "uniform") + dataset_v2s = VideoSpeechDataset( + transform=image_transform, + resolution=preproc_config.resolution, + num_frames=speech_cfg.get("num_frames_speech", 4), + video_root=video_root_cfg, + audio_root=audio_root_cfg, + speech_dir_name=speech_dir_name_cfg, + index_path=index_path_cfg, + sample_method=sample_method_v2s, + precomputed_tokens_root=precomputed_root_cfg, + validate_paths=validate_paths_cfg, + index_cache_path=index_cache_path_cfg, + max_video_seconds=max_video_seconds_cfg, + ) + + sampler_v2s = build_distributed_sampler( + dataset_v2s, + shuffle=True, + drop_last=True, + ) + + train_dataloader_v2s = DataLoader( + dataset_v2s, + batch_size=batch_size_v2s_cfg, + num_workers=dataset_config.num_workers, + collate_fn=collate_fn_v2s, + sampler=sampler_v2s, + shuffle=sampler_v2s is None, + drop_last=True, + pin_memory=pin_memory, + timeout=dataloader_timeout, + persistent_workers=persistent_workers, + ) + + # Speech Dataset + logger.info("Loading Speech dataset") + dataset_sm = MixedSpeechTextDataset(config.dataset.params.audio_data) + + # Speech-to-Speech Dataset (EMOVA + Instruct S2S) + dataset_s2s = None + sampler_s2s = None + train_dataloader_s2s = None + if config.training.batch_size_s2s > 0: + dataset_s2s = Speech2SpeechDataset(dataset_config.get("speech2speech", [])) + + # Multi-image interleaved dataset (MMU-style) + logger.info("Loading MMU dataset") + dataset_mmu = None + sampler_mmu = None + train_dataloader_mmu = None + if config.training.batch_size_mmu > 0: + mmu_params = dataset_config.get("mmu_interleaved", {}) + if mmu_params is None: + mmu_kwargs = {} + elif isinstance(mmu_params, dict): + mmu_kwargs = mmu_params + else: + mmu_kwargs = OmegaConf.to_container(mmu_params, resolve=True) + dataset_mmu = TextImageInterleavedDataset(**mmu_kwargs) + + logger.info("Dataset Prepared.") + require_cached_audio_tokens = bool(getattr(config.dataset.params, "require_cached_audio_tokens", False)) + + def _prepare_audio_flow(paths, tokens, texts): + """Align path/token/text lists and optionally drop samples without cached tokens.""" + path_list = list(paths) if isinstance(paths, (list, tuple)) else list(paths or []) + token_iterable = tokens if isinstance(tokens, (list, tuple)) else list(tokens or []) + text_iterable = texts if isinstance(texts, (list, tuple)) else list(texts or []) + + triplets: list[tuple[Any, Any, str]] = [] + for path, token, text in zip_longest(path_list, token_iterable, text_iterable, fillvalue=None): + if path is None: + continue + triplets.append((path, token, text if text is not None else "")) + + skipped = 0 + if require_cached_audio_tokens and triplets: + filtered = [(p, t, txt) for (p, t, txt) in triplets if t is not None] + skipped = len(triplets) - len(filtered) + triplets = filtered + + if not triplets: + return [], [], [], skipped + + aligned_paths = [p for (p, _, _) in triplets] + aligned_tokens = [t for (_, t, _) in triplets] + aligned_texts = [txt for (_, _, txt) in triplets] + return aligned_paths, aligned_tokens, aligned_texts, skipped + + # Use distinct DistributedSamplers for each speech dataloader to avoid iterator interference + if accelerator.num_processes > 1: + sampler_s2t = DistributedSampler( + dataset_sm, + num_replicas=accelerator.num_processes, + rank=accelerator.process_index, + shuffle=True, + drop_last=True, + ) + sampler_t2s = DistributedSampler( + dataset_sm, + num_replicas=accelerator.num_processes, + rank=accelerator.process_index, + shuffle=True, + drop_last=True, + ) + if dataset_s2s is not None: + sampler_s2s = DistributedSampler( + dataset_s2s, + num_replicas=accelerator.num_processes, + rank=accelerator.process_index, + shuffle=True, + drop_last=True, + ) + if dataset_mmu is not None: + sampler_mmu = DistributedSampler( + dataset_mmu, + num_replicas=accelerator.num_processes, + rank=accelerator.process_index, + shuffle=True, + drop_last=True, + ) + else: + sampler_s2t = None + sampler_t2s = None + sampler_s2s = None + sampler_mmu = None + + train_dataloader_s2t = DataLoader( + dataset_sm, + batch_size=config.training.batch_size_s2t, + shuffle=False, + sampler=sampler_s2t, + collate_fn=collate_fn_audio, + num_workers=config.dataset.params.num_workers, + drop_last=True, + pin_memory=pin_memory, + timeout=dataloader_timeout, + persistent_workers=persistent_workers, + ) + train_dataloader_t2s = DataLoader( + dataset_sm, + batch_size=config.training.batch_size_t2s, + shuffle=False, + sampler=sampler_t2s, + collate_fn=collate_fn_audio, + num_workers=config.dataset.params.num_workers, + drop_last=True, + pin_memory=pin_memory, + timeout=dataloader_timeout, + persistent_workers=persistent_workers, + ) + + if dataset_s2s is not None: + train_dataloader_s2s = DataLoader( + dataset_s2s, + batch_size=batch_size_s2s_cfg, + shuffle=False, + sampler=sampler_s2s, + collate_fn=s2s_collate_fn, + num_workers=config.dataset.params.num_workers, + drop_last=True, + pin_memory=pin_memory, + timeout=dataloader_timeout, + persistent_workers=persistent_workers, + ) + + if dataset_mmu is not None: + train_dataloader_mmu = DataLoader( + dataset_mmu, + batch_size=config.training.batch_size_mmu, + shuffle=False, + sampler=sampler_mmu, + collate_fn=collate_fn_mmu_mult, + num_workers=config.dataset.params.num_workers, + drop_last=True, + pin_memory=pin_memory, + timeout=dataloader_timeout, + persistent_workers=persistent_workers, + ) + + # Combine these dataloaders into a single iterable model + iterables = {} + if train_dataloader_lm is not None: + iterables["lm_flow"] = train_dataloader_lm + if train_dataloader_mmu is not None: + iterables["mmu_flow"] = train_dataloader_mmu + if train_dataloader_s2s is not None: + iterables["s2s_flow"] = train_dataloader_s2s + + if not iterables: + raise ValueError( + "CombinedLoader requires at least one non-speech iterable when speech flows are randomized. " + "Enable another dataset (e.g., t2i, lm, mmu) or disable speech randomization." + ) + + combined_dataloader = CombinedLoader(iterables, mode=config.dataset.combined_loader_mode) + + def _num_steps(dataset_obj, batch_size_cfg): + if dataset_obj is None or batch_size_cfg <= 0: + return 0 + total_bs = batch_size_cfg * accelerator.num_processes * config.training.gradient_accumulation_steps + if total_bs <= 0: + return 0 + length = len(dataset_obj) + if length == 0: + return 0 + return math.ceil(length / total_bs) + + num_update_steps_per_epoch_t2i = _num_steps(dataset_t2i, config.training.batch_size_t2i) + num_update_steps_per_epoch_i2i = _num_steps(dataset_i2i, config.training.batch_size_t2i) + num_update_steps_per_epoch_lm = _num_steps(dataset_lm, config.training.batch_size_lm) + num_update_steps_per_epoch_s2t = _num_steps(dataset_sm, config.training.batch_size_s2t) + num_update_steps_per_epoch_t2s = _num_steps(dataset_sm, config.training.batch_size_t2s) + num_update_steps_per_epoch_s2s = _num_steps(dataset_s2s, batch_size_s2s_cfg) + num_update_steps_per_epoch_v2t = _num_steps(dataset_v2t, batch_size_v2t_cfg) + num_update_steps_per_epoch_v2s = _num_steps(dataset_v2s, batch_size_v2s_cfg) + num_update_steps_per_epoch_mmu = _num_steps(dataset_mmu, config.training.batch_size_mmu) + + # Calculate num_train_epochs + num_update_steps_per_epoch = max( + num_update_steps_per_epoch_t2i, + num_update_steps_per_epoch_lm, + num_update_steps_per_epoch_s2t, + num_update_steps_per_epoch_t2s, + num_update_steps_per_epoch_v2t, + num_update_steps_per_epoch_v2s, + num_update_steps_per_epoch_s2s, + num_update_steps_per_epoch_mmu, + num_update_steps_per_epoch_i2i, + ) + + num_train_epochs = math.ceil(config.training.max_train_steps / num_update_steps_per_epoch) if num_update_steps_per_epoch > 0 else 1 + + logger.info(f"len of T2I: {len(dataset_t2i) if dataset_t2i is not None else 0}") + logger.info(f"len of I2I: {len(dataset_i2i) if dataset_i2i is not None else 0}") + logger.info(f"len of LM: {len(dataset_lm)}") + logger.info(f"len of Speech: {len(dataset_sm)}") + logger.info(f"len of Video Caption: {len(dataset_v2t) if dataset_v2t is not None else 0}") + logger.info(f"len of Video Speech: {len(dataset_v2s) if dataset_v2s is not None else 0}") + logger.info(f"len of S2S: {len(dataset_s2s)}") + logger.info(f"len of MMU: {len(dataset_mmu)}") + + logger.info(f"Train stpes: {config.training.max_train_steps}") + logger.info(f"Num train epochs: {num_train_epochs}") + + ################################## + # MODEL RESUME # + ################################# + global_step = 0 + first_epoch = 0 + start_step = 0 + + if config.experiment.resume_from_checkpoint: + dirs = os.listdir(config.experiment.output_dir) + logger.info(f"dirs: {dirs}") + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + logger.info(f"path: {path}") + if path is not None: + path = os.path.join(config.experiment.output_dir, path) + logger.info(f"Resuming from checkpoint: {path}") + global_step = start_step = int(os.path.basename(path).split("-")[1]) + first_epoch = global_step // num_update_steps_per_epoch + if os.path.exists(f'{path}/unwrapped_model/pytorch_model.bin'): + state_dict = torch.load(f'{path}/unwrapped_model/pytorch_model.bin', map_location="cpu") + model.load_state_dict(state_dict, strict=True) + del state_dict + elif os.path.exists(f'{path}/unwrapped_model/pytorch_model.bin.index.json'): + from safetensors.torch import load_file + from transformers.modeling_utils import load_sharded_checkpoint + load_sharded_checkpoint(model, f'{path}/unwrapped_model/') + # if safetensors sharded checkpoint exists + elif os.path.exists(f'{path}/unwrapped_model/model.safetensors.index.json'): + from transformers.modeling_utils import load_sharded_checkpoint + load_sharded_checkpoint( + model, + f'{path}/unwrapped_model/', + ) + else: + raise FileNotFoundError(f"Checkpoint {path}/unwrapped_model/pytorch_model.bin or safetensors not found") + else: + logger.info("Not resuming from checkpoint") + + ################################## + # Prepare accelerator # + ################################# + logger.info("Preparing model, optimizer and dataloaders") + + lr_scheduler = get_scheduler( + config.lr_scheduler.scheduler, + optimizer=optimizer, + num_training_steps=config.training.max_train_steps, + num_warmup_steps=config.lr_scheduler.params.warmup_steps, + min_lr_scale=config.lr_scheduler.params.min_lr_scale + ) + + # model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler) + model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler) + + lr_scheduler = get_scheduler( + config.lr_scheduler.scheduler, + optimizer=optimizer, + num_training_steps=config.training.max_train_steps, + num_warmup_steps=config.lr_scheduler.params.warmup_steps, + min_lr_scale=config.lr_scheduler.params.min_lr_scale + ) + + vq_model_image.to(device=accelerator.device) + vq_model_audio.to(device=accelerator.device) + + mask_dtype = model.get_input_embeddings().weight.dtype + + def _log_and_flag_failure(message: str, exc: Exception = None): + """Log preprocessing failures on both logger and accelerator console.""" + if exc is not None: + logger.exception(message) + else: + logger.error(message) + accelerator.print(message) + + def _maybe_trim_audio_file(audio_path: Union[str, os.PathLike], max_duration: float) -> tuple[Union[str, os.PathLike], Optional[str]]: + """Return a path to an audio file trimmed to max_duration seconds. + + If trimming succeeds, returns (trimmed_path, temp_path) where trimmed_path is the + file to use for encoding and temp_path should be deleted afterwards. If trimming + fails, returns (audio_path, None). + """ + if max_duration <= 0: + return audio_path, None + trim_timeout = float(getattr(config.dataset.preprocessing, "audio_trim_timeout_sec", 30.0)) + try: + tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False) + tmp_path = tmp.name + tmp.close() + cmd = [ + "ffmpeg", + "-y", + "-hide_banner", + "-loglevel", + "error", + "-i", + str(audio_path), + "-t", + str(max_duration), + "-c", + "copy", + tmp_path, + ] + subprocess.run(cmd, check=True, timeout=trim_timeout) + return tmp_path, tmp_path + except Exception as exc: + warnings.warn(f"Failed to trim audio {audio_path} to {max_duration}s: {exc}") + try: + if 'tmp_path' in locals() and os.path.exists(tmp_path): + os.remove(tmp_path) + except OSError: + pass + return audio_path, None + + def _format_path_for_log(path: Union[str, os.PathLike, torch.Tensor, None]) -> str: + if isinstance(path, (str, os.PathLike)): + try: + return os.fspath(path) + except TypeError: + return str(path) + if isinstance(path, torch.Tensor): + return f"" + if isinstance(path, np.ndarray): + return f"" + if isinstance(path, Sequence) and not isinstance(path, (str, bytes, os.PathLike)): + try: + return f"" + except Exception: + return "" + return repr(path) + + def safe_audio_encode(audio_path: Union[str, torch.Tensor, np.ndarray, Sequence[int]], flow_name: str): + if isinstance(audio_path, torch.Tensor): + return audio_path.cpu().clone(), None + if isinstance(audio_path, np.ndarray): + try: + tensor = torch.from_numpy(audio_path).to(dtype=torch.long) + except Exception as exc: + raise RuntimeError(f"Failed to convert numpy audio tokens to tensor for flow '{flow_name}': {exc}") from exc + return tensor, None + if isinstance(audio_path, Sequence) and not isinstance(audio_path, (str, bytes, os.PathLike)): + try: + tensor = torch.as_tensor(audio_path, dtype=torch.long) + except Exception as exc: + raise RuntimeError(f"Failed to convert cached audio tokens to tensor for flow '{flow_name}': {exc}") from exc + return tensor, None + path_repr = _format_path_for_log(audio_path) + if logger.isEnabledFor(logging.DEBUG): + logger.debug( + "[rank %s] (%s) audio encode request: %s", + accelerator.process_index, + flow_name, + path_repr, + ) + + max_retries = int(getattr(config.dataset.preprocessing, "audio_encode_max_retries", 3)) + backoff = float(getattr(config.dataset.preprocessing, "audio_encode_retry_backoff_sec", 0.5)) + duration_limit = float(getattr(config.dataset.preprocessing, "max_audio_duration_sec", 15.0)) + + cached = _load_cached_audio_tokens(audio_path) + if cached is not None: + if logger.isEnabledFor(logging.DEBUG): + logger.debug( + "[rank %s] (%s) audio encode hit cache: %s", + accelerator.process_index, + flow_name, + path_repr, + ) + return cached, None + + for attempt in range(1, max_retries + 1): + trimmed_path: Union[str, os.PathLike] = audio_path + temp_path: Optional[str] = None + try: + if isinstance(audio_path, (str, os.PathLike)): + trimmed_path, temp_path = _maybe_trim_audio_file(audio_path, duration_limit) + if logger.isEnabledFor(logging.DEBUG): + logger.debug( + "[rank %s] (%s) audio encode attempt %d/%d (trimmed=%s): %s", + accelerator.process_index, + flow_name, + attempt, + max_retries, + "yes" if temp_path is not None else "no", + _format_path_for_log(trimmed_path), + ) + tokens = vq_model_audio.encode(str(trimmed_path)).cpu() + _store_cached_audio_tokens(audio_path, tokens) + if logger.isEnabledFor(logging.DEBUG): + logger.debug( + "[rank %s] (%s) audio encode success: %s", + accelerator.process_index, + flow_name, + path_repr, + ) + return tokens, None + except Exception as exc: + if attempt == max_retries: + msg = ( + f"[Rank {accelerator.process_index}] {flow_name} audio encode failed " + f"for '{audio_path}': {exc}" + ) + _log_and_flag_failure(msg, exc) + return None, msg + sleep_time = min(backoff * attempt, 2.0) + time.sleep(sleep_time) + finally: + if temp_path is not None and os.path.exists(temp_path): + try: + os.remove(temp_path) + except OSError: + pass + + def safe_video_get_code(video_tensor_sample: torch.Tensor, sample_index: int): + max_retries = int(getattr(config.dataset.preprocessing, "video_encode_max_retries", 3)) + backoff = float(getattr(config.dataset.preprocessing, "video_encode_retry_backoff_sec", 0.5)) + for attempt in range(1, max_retries + 1): + try: + if logger.isEnabledFor(logging.DEBUG): + logger.debug( + "[rank %s] video encode request sample=%d attempt=%d/%d", + accelerator.process_index, + sample_index, + attempt, + max_retries, + ) + video_token = vq_model_image.get_code(video_tensor_sample) + if logger.isEnabledFor(logging.DEBUG): + logger.debug( + "[rank %s] video encode success sample=%d", + accelerator.process_index, + sample_index, + ) + return video_token, None + except Exception as exc: + if attempt == max_retries: + msg = ( + f"[Rank {accelerator.process_index}] v2t video encode failed " + f"for sample index {sample_index}: {exc}" + ) + _log_and_flag_failure(msg, exc) + return None, msg + logger.warning( + "[rank %s] video encode retry sample=%d attempt=%d/%d error=%s", + accelerator.process_index, + sample_index, + attempt, + max_retries, + exc, + ) + sleep_time = min(backoff * attempt, 2.0) + time.sleep(sleep_time) + + def safe_image_get_code(image_tensor_sample: torch.Tensor, sample_index: int): + max_retries = int(getattr(config.dataset.preprocessing, "image_encode_max_retries", 3)) + backoff = float(getattr(config.dataset.preprocessing, "image_encode_retry_backoff_sec", 0.5)) + for attempt in range(1, max_retries + 1): + try: + if logger.isEnabledFor(logging.DEBUG): + logger.debug( + "[rank %s] image encode request sample=%d attempt=%d/%d", + accelerator.process_index, + sample_index, + attempt, + max_retries, + ) + if image_tensor_sample.dim() == 3: + image_tensor_sample = image_tensor_sample.unsqueeze(0) + elif image_tensor_sample.dim() != 4: + raise ValueError( + f"Expected image tensor with 3 or 4 dims, got shape {tuple(image_tensor_sample.shape)}" + ) + image_token = vq_model_image.get_code(image_tensor_sample) + if logger.isEnabledFor(logging.DEBUG): + logger.debug( + "[rank %s] image encode success sample=%d", + accelerator.process_index, + sample_index, + ) + return image_token, None + except Exception as exc: + if attempt == max_retries: + msg = ( + f"[Rank {accelerator.process_index}] s2s image encode failed " + f"for sample index {sample_index}: {exc}" + ) + _log_and_flag_failure(msg, exc) + return None, msg + logger.warning( + "[rank %s] image encode retry sample=%d attempt=%d/%d error=%s", + accelerator.process_index, + sample_index, + attempt, + max_retries, + exc, + ) + sleep_time = min(backoff * attempt, 2.0) + time.sleep(sleep_time) + + def _decode_single_image(single_like): + if single_like is None: + return None + if isinstance(single_like, Image.Image): + return single_like.convert('RGB') + + data_bytes = None + + if isinstance(single_like, (bytes, bytearray)): + data_bytes = bytes(single_like) + elif isinstance(single_like, str): + try: + data_bytes = base64.b64decode(single_like) + except (binascii.Error, ValueError): + if os.path.isfile(single_like): + try: + with open(single_like, 'rb') as fh: + data_bytes = fh.read() + except OSError: + data_bytes = None + elif isinstance(single_like, dict): + binary_payload = single_like.get('bytes') + if binary_payload is not None: + data_bytes = binary_payload + else: + path_value = single_like.get('path') + if path_value and os.path.isfile(path_value): + try: + with open(path_value, 'rb') as fh: + data_bytes = fh.read() + except OSError: + data_bytes = None + + if data_bytes is None: + return None + + try: + with Image.open(BytesIO(data_bytes)) as img: + return img.convert('RGB') + except Exception: + return None + + def maybe_decode_image(image_like): + if isinstance(image_like, (list, tuple)): + return [_decode_single_image(item) for item in image_like] + return _decode_single_image(image_like) + + ################################## + # Training # + ################################# + logger.info("***** Running training *****") + logger.info(f" Num training steps = {config.training.max_train_steps}") + logger.info(f" Instantaneous batch size per device = {total_batch_size_per_gpu}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {config.training.gradient_accumulation_steps}") + + @torch.no_grad() + def prepare_inputs_and_labels( + pixel_values_or_image_ids: Union[torch.FloatTensor, torch.LongTensor], + texts: Union[str, str], + min_masking_rate: float = 0.0, + is_train: bool = True, + seed: int = None + ): + + image_tokens = vq_model_image.get_code(pixel_values_or_image_ids) + image_tokens = image_tokens + len(uni_prompting.text_tokenizer) + # create MLM mask and labels + input_ids, labels, loss_weight, mask_prob = mask_or_random_replace_tokens( + image_tokens, + mask_id, + config, + mask_schedule=mask_schedule, + is_train=is_train, + ) + input_ids, masks, labels = uni_prompting((texts, input_ids, labels), 't2i') + return input_ids, labels, mask_prob, image_tokens, masks + + @torch.no_grad() + def prepare_inputs_and_labels_for_i2i( + source_images: torch.FloatTensor, + target_images: torch.FloatTensor, + prompts: list[str], + is_train: bool = True, + ): + """Build masked i2i sequences from source/target image pairs.""" + + # Tokenize source/target images with VQ model and offset by text vocab size + source_tokens = vq_model_image.get_code(source_images) + len(uni_prompting.text_tokenizer) + target_tokens = vq_model_image.get_code(target_images) + len(uni_prompting.text_tokenizer) + + cond_dropout_prob = config.training.get( + "i2i_cond_dropout_prob", + config.training.cond_dropout_prob, + ) + + if is_train and torch.rand(1, device=source_tokens.device).item() < cond_dropout_prob: + effective_prompts = [''] * len(prompts) + masked_target_source = source_tokens + else: + effective_prompts = list(prompts) + masked_target_source = target_tokens + + masked_target_tokens, labels, _, mask_prob = mask_or_random_replace_tokens( + masked_target_source, + mask_id, + config, + mask_schedule=mask_schedule, + is_train=is_train, + ) + + input_ids, attention_masks, labels = uni_prompting( + (effective_prompts, source_tokens, masked_target_tokens, labels), + 'i2i' + ) + + return input_ids, labels, mask_prob, attention_masks + + @torch.no_grad() + def prepare_inputs_and_labels_for_text( + texts: Union[str, str], max_seq_len, eps=1e-3 + ): + # create MLM mask and labels + + input_ids_lm, prompt_mask, labels_lm = uni_prompting((texts, max_seq_len), 'lm') + b, l = input_ids_lm.shape + t = torch.rand(b, device=input_ids_lm.device) + p_mask = (1 - eps) * t + eps + p_mask = p_mask[:, None].repeat(1, l) + + masked_indices = torch.rand((b, l), device=input_ids_lm.device) < p_mask + # 126336 is used for [MASK] token + noisy_batch = torch.where(masked_indices, mask_id, input_ids_lm) + masked_indices = noisy_batch == mask_id + + return noisy_batch, labels_lm, p_mask + + # Video also uses this. + @torch.no_grad() + def prepare_inputs_and_labels_for_mmu( + input_ids_mmu, prompt_masks, labels_mmu, eps=1e-3 + ): + b, l = input_ids_mmu.shape + t = torch.rand(b, device=input_ids_mmu.device) + p_mask = (1 - eps) * t + eps + p_mask = p_mask[:, None].repeat(1, l) + + masked_indices = torch.rand((b, l), device=input_ids_mmu.device) < p_mask + # 126336 is used for [MASK] token + noisy_batch = torch.where(masked_indices, mask_id, input_ids_mmu) + masked_indices = noisy_batch == mask_id + noisy_batch[prompt_masks.bool()] = input_ids_mmu[prompt_masks.bool()] + masked_indices = noisy_batch == mask_id + + prompt_masks = prompt_masks.to(torch.int64) + answer_lengths = torch.sum((1 - prompt_masks), dim=-1, keepdim=True) + answer_lengths = answer_lengths.repeat(1, noisy_batch.shape[1]) + + return noisy_batch, labels_mmu, p_mask, answer_lengths + + @torch.no_grad() + def prepare_inputs_and_labels_for_t2s( + input_ids_t2s, prompt_masks, labels_t2s, eps=1e-3 + ): + b, l = input_ids_t2s.shape + t = torch.rand(b, device=input_ids_t2s.device) + p_mask = (1 - eps) * t + eps + p_mask = p_mask[:, None].repeat(1, l) + + masked_indices = torch.rand((b, l), device=input_ids_t2s.device) < p_mask + noisy_batch = torch.where(masked_indices, mask_id, input_ids_t2s) + masked_indices = noisy_batch == mask_id + + noisy_batch[prompt_masks.bool()] = input_ids_t2s[prompt_masks.bool()] + masked_indices = noisy_batch == mask_id + + prompt_masks = prompt_masks.to(torch.int64) + answer_lengths = torch.sum((1 - prompt_masks), dim=-1, keepdim=True) + answer_lengths = answer_lengths.repeat(1, noisy_batch.shape[1]) + + return noisy_batch, labels_t2s, p_mask, answer_lengths + + @torch.no_grad() + def prepare_inputs_and_labels_for_v2s( + input_ids_v2s: torch.Tensor, + prompt_masks: torch.Tensor, + labels_v2s: torch.Tensor, + eps: float = 1e-3, + ): + b, l = input_ids_v2s.shape + device_local = input_ids_v2s.device + + p_mask = eps + (1.0 - eps) * torch.rand(b, device=device_local) + p_mask = p_mask.unsqueeze(1).expand(b, l) + + rand_mat = torch.rand((b, l), device=device_local) + answer_region = (prompt_masks == 0) + masked_indices = (rand_mat < p_mask) & answer_region + + noisy_batch = input_ids_v2s.clone() + noisy_batch[masked_indices] = mask_id + + answer_lengths = answer_region.sum(dim=-1, keepdim=True) + expanded_lengths = answer_lengths.expand(b, l) + + return noisy_batch.long(), labels_v2s.long(), p_mask, expanded_lengths.long() + + + @torch.no_grad() + def prepare_inputs_and_labels_for_s2s( + input_ids_s2s, prompt_masks, labels_s2s, eps=1e-3 + ): + b, l = input_ids_s2s.shape + t = torch.rand(b, device=input_ids_s2s.device) + p_mask = (1 - eps) * t + eps + p_mask = p_mask[:, None].repeat(1, l) + + masked_indices = torch.rand((b, l), device=input_ids_s2s.device) < p_mask + noisy_batch = torch.where(masked_indices, mask_id, input_ids_s2s) + masked_indices = noisy_batch == mask_id + + noisy_batch[prompt_masks.bool()] = input_ids_s2s[prompt_masks.bool()] + masked_indices = noisy_batch == mask_id + + prompt_masks = prompt_masks.to(torch.int64) + answer_lengths = torch.sum((1 - prompt_masks), dim=-1, keepdim=True) + answer_lengths = answer_lengths.repeat(1, noisy_batch.shape[1]) + + return noisy_batch, labels_s2s, p_mask, answer_lengths + + + @torch.no_grad() + def prepare_inputs_and_labels_for_s2t( + input_ids_mmu, prompt_masks, labels_mmu, eps=1e-3 + ): + b, l = input_ids_mmu.shape + t = torch.rand(b, device=input_ids_mmu.device) + p_mask = (1 - eps) * t + eps + p_mask = p_mask[:, None].repeat(1, l) + + masked_indices = torch.rand((b, l), device=input_ids_mmu.device) < p_mask + # 126336 is used for [MASK] token + noisy_batch = torch.where(masked_indices, mask_id, input_ids_mmu) + masked_indices = noisy_batch == mask_id + noisy_batch[prompt_masks.bool()] = input_ids_mmu[prompt_masks.bool()] + masked_indices = noisy_batch == mask_id + + prompt_masks = prompt_masks.to(torch.int64) + answer_lengths = torch.sum((1 - prompt_masks), dim=-1, keepdim=True) + answer_lengths = answer_lengths.repeat(1, noisy_batch.shape[1]) + + return noisy_batch, labels_mmu, p_mask, answer_lengths + + batch_time_m = AverageMeter() + data_time_m = AverageMeter() + end = time.time() + + v2t_iterator: Optional[Iterator] = None + v2s_iterator: Optional[Iterator] = None + t2i_iterator: Optional[Iterator] = None + i2i_iterator: Optional[Iterator] = None + + def _next_from_v2t(): + nonlocal v2t_iterator + if train_dataloader_v2t is None: + return None + try: + return next(v2t_iterator) + except StopIteration: + v2t_iterator = iter(train_dataloader_v2t) + return next(v2t_iterator) + + def _next_from_v2s(): + nonlocal v2s_iterator + if train_dataloader_v2s is None: + return None + try: + return next(v2s_iterator) + except StopIteration: + v2s_iterator = iter(train_dataloader_v2s) + return next(v2s_iterator) + + def _next_from_t2i(): + nonlocal t2i_iterator + if train_dataloader_t2i is None: + return None + try: + return next(t2i_iterator) + except StopIteration: + t2i_iterator = iter(train_dataloader_t2i) + return next(t2i_iterator) + + def _next_from_i2i(): + nonlocal i2i_iterator + if train_dataloader_i2i is None: + return None + try: + return next(i2i_iterator) + except StopIteration: + i2i_iterator = iter(train_dataloader_i2i) + return next(i2i_iterator) + + def _next_from_s2t(): + nonlocal s2t_iterator + if train_dataloader_s2t is None: + return None + try: + return next(s2t_iterator) + except StopIteration: + s2t_iterator = iter(train_dataloader_s2t) + return next(s2t_iterator) + + def _next_from_t2s(): + nonlocal t2s_iterator + if train_dataloader_t2s is None: + return None + try: + return next(t2s_iterator) + except StopIteration: + t2s_iterator = iter(train_dataloader_t2s) + return next(t2s_iterator) + + v2t_iterator = iter(train_dataloader_v2t) if train_dataloader_v2t is not None else None + v2s_iterator = iter(train_dataloader_v2s) if train_dataloader_v2s is not None else None + t2i_iterator = iter(train_dataloader_t2i) if train_dataloader_t2i is not None else None + i2i_iterator = iter(train_dataloader_i2i) if train_dataloader_i2i is not None else None + s2t_iterator = iter(train_dataloader_s2t) if train_dataloader_s2t is not None else None + t2s_iterator = iter(train_dataloader_t2s) if train_dataloader_t2s is not None else None + + for epoch in tqdm(range(first_epoch, num_train_epochs), desc="Epochs", disable=not accelerator.is_main_process, position=0): + # Ensure all samplers reshuffle in a rank-consistent way each epoch + try: + if isinstance(sampler_t2i, DistributedSampler): + sampler_t2i.set_epoch(epoch) + if isinstance(sampler_i2i, DistributedSampler): + sampler_i2i.set_epoch(epoch) + if isinstance(sampler_v2t, DistributedSampler): + sampler_v2t.set_epoch(epoch) + if isinstance(sampler_v2s, DistributedSampler): + sampler_v2s.set_epoch(epoch) + if accelerator.num_processes > 1: + if sampler_s2t is not None: + sampler_s2t.set_epoch(epoch) + if sampler_t2s is not None: + sampler_t2s.set_epoch(epoch) + if sampler_s2s is not None: + sampler_s2s.set_epoch(epoch) + except Exception: + pass + model.train() + combined_iterator = iter(combined_dataloader) + while True: + skip_local = 0 + timeout_encountered = False + timeout_message: Optional[str] = None + try: + batch, batch_idx, dataloader_idx = next(combined_iterator) + except StopIteration: + break + except RuntimeError as exc: + if "DataLoader timed out" in str(exc): + skip_local = 1 + timeout_encountered = True + timeout_message = str(exc) + batch = None + batch_idx = None + dataloader_idx = None + else: + raise + + if batch is None: + skip_local = 1 + + skip_tensor = torch.tensor(skip_local, device=accelerator.device, dtype=torch.int32) + skip_sum = accelerator.reduce(skip_tensor, reduction="sum") + if skip_sum.item() > 0: + timeout_tensor = torch.tensor(1 if timeout_encountered else 0, device=accelerator.device, dtype=torch.int32) + timeout_sum = accelerator.reduce(timeout_tensor, reduction="sum") + if accelerator.is_main_process: + if timeout_sum.item() > 0: + logger.warning( + "Skipping global step %s due to DataLoader timeout: %s", + global_step, + timeout_message or "timeout on non-main rank", + ) + else: + logger.warning( + "Skipping global step %s due to empty batch from CombinedLoader.", + global_step, + ) + batch_time_m.reset() + data_time_m.reset() + end = time.time() + continue + + v2t_batch = None + v2s_batch = None + selected_v2_branch: Optional[str] = None + t2i_batch = None + i2i_batch = None + selected_x2i_branch: Optional[str] = None + + v2_choices: list[str] = [] + if train_dataloader_v2t is not None: + v2_choices.append("v2t") + if train_dataloader_v2s is not None: + v2_choices.append("v2s") + + if v2_choices: + choice_idx = global_step % len(v2_choices) + + selected_v2_branch = v2_choices[choice_idx] + if selected_v2_branch == "v2t": + v2t_batch = _next_from_v2t() + else: + v2s_batch = _next_from_v2s() + + batch["v2t_flow"] = v2t_batch + batch["v2s_flow"] = v2s_batch + + # Initialize speech flows with empty placeholders; they will be populated if selected. + batch["s2t_flow"] = _empty_audio_batch() + batch["t2s_flow"] = _empty_audio_batch() + + speech_choices: list[str] = [] + if train_dataloader_s2t is not None: + speech_choices.append("s2t") + if train_dataloader_t2s is not None: + speech_choices.append("t2s") + + selected_speech_branch: Optional[str] = None + if speech_choices: + choice_idx = global_step % len(speech_choices) + + selected_speech_branch = speech_choices[choice_idx] + if selected_speech_branch == "s2t": + speech_batch = _next_from_s2t() + if speech_batch is None: + skip_local = 1 + else: + batch["s2t_flow"] = speech_batch + else: + speech_batch = _next_from_t2s() + if speech_batch is None: + skip_local = 1 + else: + batch["t2s_flow"] = speech_batch + + x2i_choices: list[str] = [] + if train_dataloader_t2i is not None: + x2i_choices.append("t2i") + if train_dataloader_i2i is not None: + x2i_choices.append("i2i") + + if x2i_choices: + choice_idx = global_step % len(x2i_choices) + + selected_x2i_branch = x2i_choices[choice_idx] + if selected_x2i_branch == "t2i": + t2i_batch = _next_from_t2i() + else: + i2i_batch = _next_from_i2i() + + # Synchronize skip decision across all ranks to avoid collective mismatches + required_flows = ["t2s_flow", "s2t_flow"] + if train_dataloader_lm is not None: + required_flows.append("lm_flow") + if train_dataloader_mmu is not None: + required_flows.append("mmu_flow") + if train_dataloader_s2s is not None: + required_flows.append("s2s_flow") + + local_skip = 0 + if selected_v2_branch == "v2t" and v2t_batch is None: + local_skip = 1 + elif selected_v2_branch == "v2s" and v2s_batch is None: + local_skip = 1 + else: + for key in required_flows: + if batch.get(key) is None: + local_skip = 1 + break + if selected_x2i_branch == "t2i": + if t2i_batch is None: + local_skip = 1 + else: + t2i_images = t2i_batch["t2i"].get("images") + if not isinstance(t2i_images, torch.Tensor) or t2i_images.shape[0] == 0: + local_skip = 1 + elif selected_x2i_branch == "i2i": + if i2i_batch is None: + local_skip = 1 + else: + i2i_sources = i2i_batch["i2i"].get("source_images") + i2i_targets = i2i_batch["i2i"].get("target_images") + if ( + not isinstance(i2i_sources, torch.Tensor) + or not isinstance(i2i_targets, torch.Tensor) + or i2i_sources.shape[0] == 0 + or i2i_targets.shape[0] == 0 + ): + local_skip = 1 + try: + skip_tensor = torch.tensor(local_skip, device=accelerator.device, dtype=torch.int32) + skip_sum = accelerator.reduce(skip_tensor, reduction='sum') + should_skip = skip_sum.item() > 0 + except Exception: + # Fallback if reduce isn't available for any reason + should_skip = local_skip == 1 + + if should_skip: + if accelerator.is_main_process and local_skip: + logger.warning(f"Skipping step {global_step} (required multimodal batch missing) [synced]") + continue + + device = accelerator.device + batch_size_v2s = 0 + input_ids_v2s = torch.empty((0, 1), dtype=torch.long, device=device) + labels_v2s = torch.empty((0, 1), dtype=torch.long, device=device) + p_mask_v2s = torch.empty((0, 1), dtype=torch.float32, device=device) + answer_lengths_v2s = torch.empty((0, 1), dtype=torch.long, device=device) + # Text-to-image samples + batch_size_t2i = 0 + mask_prob = torch.tensor(0.0, device=device) + t2i_masks = torch.empty((0, 1), dtype=torch.long, device=device) + input_ids_t2i = torch.empty((0, 1), dtype=torch.long, device=device) + labels_t2i = torch.empty((0, 1), dtype=torch.long, device=device) + batch_size_i2i = 0 + mask_prob_i2i = torch.tensor(0.0, device=device) + input_ids_i2i = torch.empty((0, 1), dtype=torch.long, device=device) + labels_i2i = torch.empty((0, 1), dtype=torch.long, device=device) + attention_masks_i2i = torch.empty((0, 1), dtype=torch.long, device=device) + + if selected_x2i_branch == "t2i" and t2i_batch is not None: + t2i_texts = t2i_batch["t2i"].get("texts", []) + t2i_images_tensor = t2i_batch["t2i"].get("images") + if isinstance(t2i_images_tensor, torch.Tensor) and t2i_images_tensor.shape[0] > 0: + t2i_images_tensor = t2i_images_tensor.to(device, non_blocking=True) + batch_size_t2i = t2i_images_tensor.shape[0] + ( + input_ids_t2i, + labels_t2i, + mask_prob, + _, + t2i_masks, + ) = prepare_inputs_and_labels(t2i_images_tensor, t2i_texts, config.training.min_masking_rate) + input_ids_t2i = input_ids_t2i.to(device, non_blocking=True) + labels_t2i = labels_t2i.to(device, non_blocking=True) + t2i_masks = t2i_masks.to(device, non_blocking=True) + if mask_prob.device != device: + mask_prob = mask_prob.to(device) + + if selected_x2i_branch == "i2i" and i2i_batch is not None: + i2i_prompts = i2i_batch["i2i"].get("prompts", []) + i2i_source_tensor = i2i_batch["i2i"].get("source_images") + i2i_target_tensor = i2i_batch["i2i"].get("target_images") + if ( + isinstance(i2i_source_tensor, torch.Tensor) + and isinstance(i2i_target_tensor, torch.Tensor) + and i2i_source_tensor.shape[0] > 0 + and i2i_target_tensor.shape[0] > 0 + ): + i2i_source_tensor = i2i_source_tensor.to(device, non_blocking=True) + i2i_target_tensor = i2i_target_tensor.to(device, non_blocking=True) + batch_size_i2i = i2i_source_tensor.shape[0] + ( + input_ids_i2i, + labels_i2i, + mask_prob_i2i, + attention_masks_i2i, + ) = prepare_inputs_and_labels_for_i2i( + i2i_source_tensor, + i2i_target_tensor, + i2i_prompts, + is_train=True, + ) + input_ids_i2i = input_ids_i2i.to(device, non_blocking=True) + labels_i2i = labels_i2i.to(device, non_blocking=True) + attention_masks_i2i = attention_masks_i2i.to(device, non_blocking=True) + if mask_prob_i2i.device != device: + mask_prob_i2i = mask_prob_i2i.to(device) + + # Language modeling samples + batch_size_lm = 0 + input_ids_lm = torch.empty((0, 1), dtype=torch.long, device=device) + labels_lm = torch.empty((0, 1), dtype=torch.long, device=device) + p_mask_lm = torch.empty((0, 1), dtype=torch.float32, device=device) + if train_dataloader_lm is not None: + lm_batch = batch.get("lm_flow") + if lm_batch is not None: + texts_lm = lm_batch["input_ids"] + batch_size_lm = len(texts_lm) + max_seq_for_lm = input_ids_t2i.shape[1] if batch_size_t2i > 0 else preproc_config.max_seq_length + input_ids_lm, labels_lm, p_mask_lm = prepare_inputs_and_labels_for_text(texts_lm, max_seq_for_lm) + input_ids_lm = input_ids_lm.to(device, non_blocking=True) + labels_lm = labels_lm.to(device, non_blocking=True) + p_mask_lm = p_mask_lm.to(device, non_blocking=True) + + if isinstance(v2t_batch, dict): + video_tensor_text_raw = v2t_batch.get("video") + texts_vid = v2t_batch.get("captions", []) + else: + video_tensor_text_raw = None + texts_vid = [] + + if isinstance(v2s_batch, dict): + video_tensor_speech_raw = v2s_batch.get("video") + speech_items = v2s_batch.get("speech", []) + else: + video_tensor_speech_raw = None + speech_items = [] + + batch_size_v2t = video_tensor_text_raw.shape[0] if isinstance(video_tensor_text_raw, torch.Tensor) else 0 + batch_size_v2s = len(speech_items) + + video_tensor_text = ( + video_tensor_text_raw.to(device, non_blocking=True) + if isinstance(video_tensor_text_raw, torch.Tensor) + else torch.empty((0, 1, 1, 1, 1), device=device) + ) + video_tensor_speech = ( + video_tensor_speech_raw.to(device, non_blocking=True) + if isinstance(video_tensor_speech_raw, torch.Tensor) + else torch.empty((0, 1, 1, 1, 1), device=device) + ) + + s2s_batch = batch.get("s2s_flow") + batch_size_s2s = 0 + if s2s_batch is not None: + batch_size_s2s = len(s2s_batch.get("emova_sft", [])) + len(s2s_batch.get("instructs2s", [])) + + mmu_batch = batch.get("mmu_flow") + batch_size_mmu = 0 + image_tensor_list = [] + texts_image = [] + if mmu_batch is not None: + image_tensor_list = mmu_batch.get("images", []) + texts_image = mmu_batch.get("text", []) + batch_size_mmu = len(image_tensor_list) + + s2t_flow = batch.get("s2t_flow", {}) + t2s_flow = batch.get("t2s_flow", {}) + audio_paths_s2t_raw, texts_s2t_raw = s2t_flow.get("audio_path", []), s2t_flow.get("text", []) + audio_paths_t2s_raw, texts_t2s_raw = t2s_flow.get("audio_path", []), t2s_flow.get("text", []) + audio_tokens_s2t_raw = s2t_flow.get("audio_tokens", []) + audio_tokens_t2s_raw = t2s_flow.get("audio_tokens", []) + + audio_paths_s2t, audio_tokens_s2t, texts_s2t, skipped_s2t = _prepare_audio_flow( + audio_paths_s2t_raw, + audio_tokens_s2t_raw, + texts_s2t_raw, + ) + audio_paths_t2s, audio_tokens_t2s, texts_t2s, skipped_t2s = _prepare_audio_flow( + audio_paths_t2s_raw, + audio_tokens_t2s_raw, + texts_t2s_raw, + ) + batch_size_s2t = len(audio_paths_s2t) + batch_size_t2s_text = len(audio_paths_t2s) + if require_cached_audio_tokens and accelerator.is_main_process: + skipped_total = skipped_s2t + skipped_t2s + if skipped_total and (global_step % 50 == 0): + logger.info( + "Skipped %d speech samples lacking cached tokens (s2t=%d, t2s=%d).", + skipped_total, + skipped_s2t, + skipped_t2s, + ) + + if batch_size_s2t > 0 and batch_size_t2s_text > 0: + drop_t2s = (global_step % 2) == 0 + + if drop_t2s: + audio_paths_t2s = [] + texts_t2s = [] + audio_tokens_t2s = [] + batch_size_t2s_text = 0 + else: + audio_paths_s2t = [] + texts_s2t = [] + audio_tokens_s2t = [] + batch_size_s2t = 0 + + batch_size_s2t = len(audio_paths_s2t) + batch_size_t2s_text = len(audio_paths_t2s) + + active_x2i_branch = selected_x2i_branch or "none" + logger.info( + f"x2i_branch: {active_x2i_branch}, batch_size_t2i: {batch_size_t2i}, batch_size_i2i: {batch_size_i2i}, batch_size_lm: {batch_size_lm}, " + f"batch_size_v2t: {batch_size_v2t}, batch_size_v2s: {batch_size_v2s}, batch_size_t2s: {batch_size_t2s_text}, " + f"batch_size_s2t: {batch_size_s2t}, batch_size_s2s: {batch_size_s2s}, batch_size_mmu: {batch_size_mmu}" + ) + offset = speech_vocab_start + + data_time_m.update(time.time() - end) + + failure_messages = [] + step_failed = False + + input_ids_vid = torch.empty((0, 1), dtype=torch.long, device=device) + labels_vid = torch.empty((0, 1), dtype=torch.long, device=device) + p_mask_vid = torch.empty((0, 1), dtype=torch.float32, device=device) + answer_lengths_vid = torch.empty((0, 1), dtype=torch.long, device=device) + + input_ids_v2s = torch.empty((0, 1), dtype=torch.long, device=device) + labels_v2s = torch.empty((0, 1), dtype=torch.long, device=device) + p_mask_v2s = torch.empty((0, 1), dtype=torch.float32, device=device) + answer_lengths_v2s = torch.empty((0, 1), dtype=torch.long, device=device) + + if batch_size_v2t > 0: + video_token_list = [] + for vid_idx, video in enumerate(video_tensor_text): + tokens, err = safe_video_get_code(video, vid_idx) + if err is not None: + failure_messages.append(err) + step_failed = True + break + tokens = tokens + len(uni_prompting.text_tokenizer) + video_token_list.append(tokens.view(-1)) + + if not step_failed and video_token_list: + video_tokens_text = torch.stack(video_token_list, dim=0) + + texts_with_prompt: List[str] + is_vid_inst = False + if texts_vid and isinstance(texts_vid[0], (list, tuple)) and isinstance(texts_vid[0][0], dict): + is_vid_inst = True + vid_inst_prompt: List[str] = [] + vid_inst_answer: List[str] = [] + for conv in texts_vid: + human_msg = "" + assistant_msg = "" + for turn in conv: + role = turn.get("from") + value = turn.get("value", "") + if role == "human": + human_msg = value.replace("\n", "") + elif role == "gpt": + assistant_msg = value + vid_inst_prompt.append(human_msg) + vid_inst_answer.append(assistant_msg) + texts_with_prompt = [ + "<|start_header_id|>user<|end_header_id|>\n" + f"{vid_inst_prompt[i]}<|eot_id|>" + "<|start_header_id|>assistant<|end_header_id|>\n" + f"{vid_inst_answer[i]}" + for i in range(len(vid_inst_answer)) + ] + else: + prompt_v2t_selected = random.choice(V2T_INSTRUCTION) + texts_with_prompt = [ + "<|start_header_id|>user<|end_header_id|>\n" + f"{prompt_v2t_selected}<|eot_id|>" + "<|start_header_id|>assistant<|end_header_id|>\n" + f"{text if isinstance(text, str) else str(text)}" + for text in texts_vid + ] + + input_ids_vid_tmp, prompt_masks_vid, labels_vid_tmp = uni_prompting((video_tokens_text, texts_with_prompt), 'v2t') + input_ids_vid_tmp, labels_vid_tmp, p_mask_vid, answer_lengths_vid = prepare_inputs_and_labels_for_mmu( + input_ids_vid_tmp, prompt_masks_vid, labels_vid_tmp + ) + input_ids_vid = input_ids_vid_tmp.to(device, non_blocking=True) + labels_vid = labels_vid_tmp.to(device, non_blocking=True) + p_mask_vid = p_mask_vid.to(device, non_blocking=True) + answer_lengths_vid = answer_lengths_vid.to(device, non_blocking=True) + else: + batch_size_v2t = 0 + + if batch_size_v2s > 0 and not step_failed: + all_audio_tokens: list[torch.Tensor] = [] + for speech_entry in speech_items: + if isinstance(speech_entry, torch.Tensor): + tokens = speech_entry.to(device, non_blocking=True) + else: + tokens, err = safe_audio_encode(speech_entry, "v2s") + if err is not None: + failure_messages.append(err) + step_failed = True + break + tokens = tokens.to(device, non_blocking=True) + tokens_with_offset = tokens + offset + all_audio_tokens.append(tokens_with_offset) + + video_token_list_v2s: list[torch.Tensor] = [] + if not step_failed: + for vid_idx, video in enumerate(video_tensor_speech): + tokens, err = safe_video_get_code(video, vid_idx) + if err is not None: + failure_messages.append(err) + step_failed = True + break + tokens = tokens + len(uni_prompting.text_tokenizer) + video_token_list_v2s.append(tokens.view(-1)) + + if not step_failed and all_audio_tokens and video_token_list_v2s: + video_tokens_v2s = torch.stack(video_token_list_v2s, dim=0) + prompts_v2s = [ + f"<|start_header_id|>user<|end_header_id|>\n{random.choice(V2S_INSTRUCTION)}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n" + for _ in range(batch_size_v2s) + ] + input_ids_v2s_tmp, prompt_masks_v2s, labels_v2s_tmp = uni_prompting( + (video_tokens_v2s, prompts_v2s, all_audio_tokens), 'v2s_ip' + ) + input_ids_v2s_tmp, labels_v2s_tmp, p_mask_v2s_tmp, answer_lengths_v2s_tmp = prepare_inputs_and_labels_for_v2s( + input_ids_v2s_tmp, prompt_masks_v2s, labels_v2s_tmp + ) + input_ids_v2s = input_ids_v2s_tmp.to(device, non_blocking=True) + labels_v2s = labels_v2s_tmp.to(device, non_blocking=True) + p_mask_v2s = p_mask_v2s_tmp.to(device, non_blocking=True) + answer_lengths_v2s = answer_lengths_v2s_tmp.to(device, non_blocking=True) + else: + batch_size_v2s = 0 + + # *-------*-------*-------*-------*-------*-------*-------*-------*-------*-------*-------* + # Build formatted sequences for speech understanding + # *-------*-------*-------*-------*-------*-------*-------*-------*-------*-------*-------* + if not step_failed and batch_size_s2t > 0: + prompt_s2t = ['<|start_header_id|>user<|end_header_id|>\n' + prompt + '<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n' for prompt in S2T_INSTRUCTION] + + all_audio_tokens = [] + if not audio_tokens_s2t: + audio_tokens_s2t = [None] * len(audio_paths_s2t) + elif len(audio_tokens_s2t) < len(audio_paths_s2t): + audio_tokens_s2t = list(audio_tokens_s2t) + [None] * (len(audio_paths_s2t) - len(audio_tokens_s2t)) + + for path, cached_tokens in zip(audio_paths_s2t, audio_tokens_s2t): + source = cached_tokens if cached_tokens is not None else path + tokens, err = safe_audio_encode(source, "s2t") + if err is not None: + failure_messages.append(err) + step_failed = True + break + tokens = tokens.to(accelerator.device, non_blocking=True) + tokens_with_offset = tokens + offset + all_audio_tokens.append(tokens_with_offset) + + if not step_failed: + prompt = random.choice(prompt_s2t) + texts_with_prompt = [f"{prompt}{text}" for text in texts_s2t] + + input_ids_s2t, prompt_masks_s2t, labels_s2t = uni_prompting((all_audio_tokens, texts_with_prompt), 's2t') + # Preserve trailing EOS tokens in s2t targets for explicit prediction. + input_ids_s2t, labels_s2t, p_mask_s2t, answer_lengths_s2t = prepare_inputs_and_labels_for_s2t(input_ids_s2t, prompt_masks_s2t, labels_s2t) + else: + input_ids_s2t = torch.empty((0, 1), dtype=torch.long, device=accelerator.device) + labels_s2t = torch.empty((0, 1), dtype=torch.long, device=accelerator.device) + p_mask_s2t = torch.empty((0, 1), dtype=torch.float32, device=accelerator.device) + answer_lengths_s2t = torch.empty((0, 1), dtype=torch.long, device=accelerator.device) + + # *-------*-------*-------*-------*-------*-------*-------*-------*-------*-------*-------* + # Build formatted sequences for speech generation + # *-------*-------*-------*-------*-------*-------*-------*-------*-------*-------*-------* + if not step_failed and batch_size_t2s_text > 0: + prompt_t2s = [prompt for prompt in T2S_INSTRUCTION] + + all_audio_tokens = [] + if not audio_tokens_t2s: + audio_tokens_t2s = [None] * len(audio_paths_t2s) + elif len(audio_tokens_t2s) < len(audio_paths_t2s): + audio_tokens_t2s = list(audio_tokens_t2s) + [None] * (len(audio_paths_t2s) - len(audio_tokens_t2s)) + + for path, cached_tokens in zip(audio_paths_t2s, audio_tokens_t2s): + source = cached_tokens if cached_tokens is not None else path + tokens, err = safe_audio_encode(source, "t2s") + if err is not None: + failure_messages.append(err) + step_failed = True + break + tokens = tokens.to(accelerator.device, non_blocking=True) + tokens_with_offset = tokens + offset + all_audio_tokens.append(tokens_with_offset) + + if not step_failed: + # Chat-style instruction formatting for T2S training + prompt = random.choice(prompt_t2s) + texts_with_prompt = [ + f"<|start_header_id|>user<|end_header_id|>\n{prompt}\n{text}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n" + for text in texts_t2s + ] + + # input_ids_t2s, prompt_masks_t2s, labels_t2s = uni_prompting((texts_with_prompt, all_audio_tokens), 't2s_ip') + input_ids_t2s, prompt_masks_t2s, labels_t2s = uni_prompting((texts_with_prompt, all_audio_tokens), 't2s') + input_ids_t2s, labels_t2s, p_mask_t2s, answer_lengths_t2s = prepare_inputs_and_labels_for_t2s(input_ids_t2s, prompt_masks_t2s, labels_t2s) + else: + input_ids_t2s = torch.empty((0, 1), dtype=torch.long, device=accelerator.device) + labels_t2s = torch.empty((0, 1), dtype=torch.long, device=accelerator.device) + p_mask_t2s = torch.empty((0, 1), dtype=torch.float32, device=accelerator.device) + answer_lengths_t2s = torch.empty((0, 1), dtype=torch.long, device=accelerator.device) + + audio_user_ids_s2s: list[torch.Tensor] = [] + audio_asst_ids_s2s: list[torch.Tensor] = [] + image_token_blocks_s2s: list[Optional[torch.Tensor]] = [] + input_ids_s2s = None + labels_s2s = None + p_mask_s2s = None + answer_lengths_s2s = None + + if not step_failed and batch_size_s2s > 0 and s2s_batch is not None: + s2s_sample_counter = 0 + + emova_samples = s2s_batch.get("emova_sft", []) + for sample_idx, (usr_ids, asst_ids, image_like) in enumerate(emova_samples): + usr_tensor = torch.tensor(usr_ids, dtype=torch.long, device=accelerator.device).unsqueeze(0) + asst_tensor = torch.tensor(asst_ids, dtype=torch.long, device=accelerator.device).unsqueeze(0) + + audio_user_ids_s2s.append(usr_tensor + offset) + audio_asst_ids_s2s.append(asst_tensor + offset) + + decoded_payload = maybe_decode_image(image_like) + if isinstance(decoded_payload, (list, tuple)): + token_payloads = [] + for local_img in decoded_payload: + if local_img is None: + continue + pixel_values = image_transform(local_img, resolution=preproc_config.resolution).to(accelerator.device) + if pixel_values.dim() == 3: + pixel_values = pixel_values.unsqueeze(0) + image_tokens_raw, err = safe_image_get_code(pixel_values, s2s_sample_counter) + if err is not None: + failure_messages.append(err) + step_failed = True + break + image_tokens = image_tokens_raw + len(uni_prompting.text_tokenizer) + token_payloads.append(image_tokens.view(1, -1).to(accelerator.device, non_blocking=True)) + s2s_sample_counter += 1 + if step_failed: + break + image_token_blocks_s2s.append(token_payloads if token_payloads else None) + else: + if decoded_payload is not None: + pixel_values = image_transform(decoded_payload, resolution=preproc_config.resolution).to(accelerator.device) + if pixel_values.dim() == 3: + pixel_values = pixel_values.unsqueeze(0) + image_tokens_raw, err = safe_image_get_code(pixel_values, s2s_sample_counter) + if err is not None: + failure_messages.append(err) + step_failed = True + break + image_tokens = image_tokens_raw + len(uni_prompting.text_tokenizer) + image_token_blocks_s2s.append(image_tokens.view(1, -1).to(accelerator.device, non_blocking=True)) + s2s_sample_counter += 1 + else: + image_token_blocks_s2s.append(None) + + instruct_samples = [] if step_failed else s2s_batch.get("instructs2s", []) + if not step_failed: + for sample_idx, (usr_wav, asst_wav, _) in enumerate(instruct_samples): + user_tokens, err_usr = safe_audio_encode(usr_wav, "s2s-user") + if err_usr is not None: + failure_messages.append(err_usr) + step_failed = True + break + asst_tokens, err_asst = safe_audio_encode(asst_wav, "s2s-assistant") + if err_asst is not None: + failure_messages.append(err_asst) + step_failed = True + break + + user_tokens = user_tokens.to(accelerator.device, non_blocking=True) + asst_tokens = asst_tokens.to(accelerator.device, non_blocking=True) + if user_tokens.size(-1) > config.dataset.preprocessing.max_aud_length: + duration_tokens = config.dataset.preprocessing.max_aud_length + user_tokens = user_tokens[..., :duration_tokens] + if asst_tokens.size(-1) > config.dataset.preprocessing.max_aud_length: + duration_tokens = config.dataset.preprocessing.max_aud_length + asst_tokens = asst_tokens[..., :duration_tokens] + if user_tokens.dim() == 1: + user_tokens = user_tokens.unsqueeze(0) + if asst_tokens.dim() == 1: + asst_tokens = asst_tokens.unsqueeze(0) + audio_user_ids_s2s.append(user_tokens + offset) + audio_asst_ids_s2s.append(asst_tokens + offset) + image_token_blocks_s2s.append(None) + s2s_sample_counter += 1 + + if not step_failed and audio_user_ids_s2s: + input_ids_s2s, prompt_masks_s2s, labels_s2s = uni_prompting( + (audio_user_ids_s2s, audio_asst_ids_s2s, image_token_blocks_s2s), + 's2s_ip' + ) + + input_ids_s2s, labels_s2s, p_mask_s2s, answer_lengths_s2s = prepare_inputs_and_labels_for_s2s( + input_ids_s2s, + prompt_masks_s2s, + labels_s2s, + ) + + if ( + answer_lengths_s2s is not None + and answer_lengths_s2s.numel() > 0 + and accelerator.is_main_process + ): + per_sample_lengths = answer_lengths_s2s[:, 0].detach().cpu() + lengths_list = [int(length) for length in per_sample_lengths.tolist()] + stats_msg = ( + f"min={int(per_sample_lengths.min().item())}, " + f"max={int(per_sample_lengths.max().item())}, " + f"mean={per_sample_lengths.float().mean().item():.2f}" + ) + logger.info("S2S answer lengths (no pad): %s | %s", lengths_list, stats_msg) + + if input_ids_s2s is None: + device = accelerator.device + input_ids_s2s = torch.empty((0, 1), dtype=torch.long, device=device) + labels_s2s = torch.empty((0, 1), dtype=torch.long, device=device) + p_mask_s2s = torch.empty((0, 1), dtype=torch.float32, device=device) + answer_lengths_s2s = torch.empty((0, 1), dtype=torch.long, device=device) + + input_ids_mmu = None + labels_mmu = None + p_mask_mmu = None + answer_lengths_mmu = None + + if not step_failed and batch_size_mmu > 0: + batch_image_ids_list = [] + batch_text_ids = [] + + for b_idx, image_list in enumerate(image_tensor_list): + per_img_ids = [] + for j, img in enumerate(image_list): + tok, err = safe_image_get_code( + img.to(accelerator.device, non_blocking=True), + sample_index=j + ) + if err is not None: + failure_messages.append(err) + step_failed = True + break + + tok = tok.to(accelerator.device, non_blocking=True).view(-1).long() + tok = tok + len(uni_prompting.text_tokenizer) + per_img_ids.append(tok) + + if step_failed: + break + + batch_image_ids_list.append(per_img_ids) + text_ids = uni_prompting.text_tokenizer.encode(texts_image[b_idx], add_special_tokens=False) + batch_text_ids.append(text_ids) + + if not step_failed: + input_ids_mmu, prompt_masks_mmu, labels_mmu = uni_prompting.mmu_mult_prompt( + batch_image_ids_list=batch_image_ids_list, + batch_text_ids=batch_text_ids, + ) + + ( + input_ids_mmu, + labels_mmu, + p_mask_mmu, + answer_lengths_mmu + ) = prepare_inputs_and_labels_for_mmu(input_ids_mmu, prompt_masks_mmu, labels_mmu) + + input_ids_mmu = input_ids_mmu.to(accelerator.device, non_blocking=True) + labels_mmu = labels_mmu.to(accelerator.device, non_blocking=True) + p_mask_mmu = p_mask_mmu.to(accelerator.device, non_blocking=True) + answer_lengths_mmu = answer_lengths_mmu.to(accelerator.device, non_blocking=True) + + if batch_size_mmu == 0 or input_ids_mmu is None: + input_ids_mmu = torch.empty((0, 1), dtype=torch.long, device=accelerator.device) + labels_mmu = torch.empty((0, 1), dtype=torch.long, device=accelerator.device) + p_mask_mmu = torch.empty((0, 1), dtype=torch.float32, device=accelerator.device) + answer_lengths_mmu = torch.empty((0, 1), dtype=torch.long, device=accelerator.device) + if not step_failed: + total_batch_size_t2s = batch_size_t2s_text + else: + total_batch_size_t2s = batch_size_t2s_text + + failure_tensor = torch.tensor(1 if step_failed else 0, device=accelerator.device, dtype=torch.int32) + failure_sum = accelerator.reduce(failure_tensor, reduction='sum') + if failure_sum.item() > 0: + if accelerator.is_main_process and failure_messages: + for msg in failure_messages: + logger.warning(f"Skipping global step {global_step} due to preprocessing failure: {msg}") + batch_time_m.reset() + data_time_m.reset() + end = time.time() + continue + + # -------------------------------------------------------------------------------- + # for name, tensor in [ + # ("t2i", input_ids_t2i), + # ("i2i", input_ids_i2i), + # ("lm", input_ids_lm), + # ("mmu", input_ids_mmu), + # ("vid", input_ids_vid), + # ("s2t", input_ids_s2t), + # ("s2s", input_ids_s2s), + # ("t2s", input_ids_t2s), + # ]: + # if tensor is not None: + # print(f"{name:>4}: shape={getattr(tensor, 'shape', None)}, len={len(tensor) if hasattr(tensor, '__len__') else 'N/A'}") + + # 1. Define padding values + pad_token_id = uni_prompting.text_tokenizer.eos_token_id + + # 2. Find the maximum sequence length in the current batch + seq_lengths = [] + if input_ids_t2i.shape[0] > 0: + seq_lengths.append(input_ids_t2i.shape[1]) + if input_ids_i2i.shape[0] > 0: + seq_lengths.append(input_ids_i2i.shape[1]) + if input_ids_lm.shape[0] > 0: + seq_lengths.append(input_ids_lm.shape[1]) + seq_lengths.extend([ + input_ids_vid.shape[1], + input_ids_v2s.shape[1], + input_ids_s2t.shape[1], + input_ids_t2s.shape[1], + ]) + if input_ids_s2s.shape[0] > 0: + seq_lengths.append(input_ids_s2s.shape[1]) + if input_ids_mmu.shape[0] > 0: + seq_lengths.append(input_ids_mmu.shape[1]) + max_len = max(seq_lengths) + + # 3. Pad all tensors to the max_len + input_ids_t2i = pad_tensor(input_ids_t2i, max_len, pad_token_id) + labels_t2i = pad_tensor(labels_t2i, max_len, -100) + if t2i_masks.shape[0] > 0: + t2i_masks = pad_tensor(t2i_masks.long(), max_len, 0) + else: + t2i_masks = torch.empty((0, max_len), dtype=torch.long, device=device) + + input_ids_i2i = pad_tensor(input_ids_i2i, max_len, pad_token_id) + labels_i2i = pad_tensor(labels_i2i, max_len, -100) + if attention_masks_i2i.shape[0] > 0: + attention_masks_i2i = pad_tensor(attention_masks_i2i.long(), max_len, 0) + else: + attention_masks_i2i = torch.empty((0, max_len), dtype=torch.long, device=device) + + + input_ids_lm = pad_tensor(input_ids_lm, max_len, pad_token_id) + labels_lm = pad_tensor(labels_lm, max_len, -100) + p_mask_lm = pad_tensor(p_mask_lm, max_len, 1.0) + + input_ids_vid = pad_tensor(input_ids_vid, max_len, pad_token_id) + input_ids_v2s = pad_tensor(input_ids_v2s, max_len, pad_token_id) + input_ids_s2t = pad_tensor(input_ids_s2t, max_len, pad_token_id) + input_ids_t2s = pad_tensor(input_ids_t2s, max_len, pad_token_id) + input_ids_s2s = pad_tensor(input_ids_s2s, max_len, pad_token_id) + input_ids_mmu = pad_tensor(input_ids_mmu, max_len, pad_token_id) + labels_vid = pad_tensor(labels_vid, max_len, -100) + labels_v2s = pad_tensor(labels_v2s, max_len, -100) + labels_s2t = pad_tensor(labels_s2t, max_len, -100) + labels_t2s = pad_tensor(labels_t2s, max_len, -100) + labels_s2s = pad_tensor(labels_s2s, max_len, -100) + labels_mmu = pad_tensor(labels_mmu, max_len, -100) + p_mask_vid = pad_tensor(p_mask_vid, max_len, 1.0) + p_mask_v2s = pad_tensor(p_mask_v2s, max_len, 1.0) + p_mask_s2t = pad_tensor(p_mask_s2t, max_len, 1.0) + p_mask_t2s = pad_tensor(p_mask_t2s, max_len, 1.0) + p_mask_s2s = pad_tensor(p_mask_s2s, max_len, 1.0) + p_mask_mmu = pad_tensor(p_mask_mmu, max_len, 1.0) + answer_lengths_vid = pad_answer_lengths(answer_lengths_vid, max_len) + answer_lengths_v2s = pad_answer_lengths(answer_lengths_v2s, max_len) + answer_lengths_s2t = pad_answer_lengths(answer_lengths_s2t, max_len) + answer_lengths_t2s = pad_answer_lengths(answer_lengths_t2s, max_len) + answer_lengths_s2s = pad_answer_lengths(answer_lengths_s2s, max_len) + answer_lengths_mmu = pad_answer_lengths(answer_lengths_mmu, max_len) + + input_ids = torch.cat(( + input_ids_t2i, + input_ids_i2i, + input_ids_lm, + input_ids_mmu, + input_ids_vid, + input_ids_v2s, + input_ids_s2t, + input_ids_s2s, + input_ids_t2s + ), dim=0) + labels = torch.cat(( + labels_t2i, + labels_i2i, + labels_lm, + labels_mmu, + labels_vid, + labels_v2s, + labels_s2t, + labels_s2s, + labels_t2s + ), dim=0) + + # w/o texts and images + if batch_size_lm == 0: + p_mask_lm = torch.empty((0, max_len), dtype=torch.float32, device=device) + if batch_size_t2i == 0 and t2i_masks.shape[0] == 0: + t2i_masks = torch.empty((0, max_len), dtype=torch.long, device=device) + + if global_step == 0 and epoch == 0: + logger.info("Input ids: {}".format(input_ids)) + logger.info("Labels: {}".format(labels)) + + logger.info("Input ids shape: {}".format(input_ids.shape)) + # with accelerator.accumulate(model): + logits, loss_t2i, loss_i2i, loss_lm, loss_mmu, loss_vid, loss_v2s, loss_s2t, loss_s2s, loss_t2s = accelerator.unwrap_model(model).forward_process( + # logits, loss_t2i, loss_lm, loss_mmu, loss_vid, loss_s2t, loss_t2s = model.forward_process( + input_ids=input_ids, + labels=labels, + batch_size_t2i=batch_size_t2i, + batch_size_i2i=batch_size_i2i, + batch_size_lm=batch_size_lm, + batch_size_mmu=batch_size_mmu, + batch_size_v2t=batch_size_v2t, + batch_size_v2s=batch_size_v2s, + batch_size_s2t=batch_size_s2t, + batch_size_s2s=batch_size_s2s, + batch_size_t2s=total_batch_size_t2s, + max_seq_length=config.dataset.preprocessing.max_seq_length, + attention_masks_i2i=attention_masks_i2i, + p_mask_lm=p_mask_lm, + p_mask_mmu=p_mask_mmu, + p_mask_vid=p_mask_vid, + p_mask_v2s=p_mask_v2s, + p_mask_s2t=p_mask_s2t, + p_mask_s2s=p_mask_s2s, + p_mask_t2s=p_mask_t2s, + answer_lengths_mmu=answer_lengths_mmu, + answer_lengths_vid=answer_lengths_vid, + answer_lengths_v2s=answer_lengths_v2s, + answer_lengths_s2t=answer_lengths_s2t, + answer_lengths_s2s=answer_lengths_s2s, + answer_lengths_t2s=answer_lengths_t2s, + t2i_masks=t2i_masks, + t2s_vocab_start=speech_vocab_start, + t2s_codebook_size=audio_codebook_size, + t2s_special_token_ids=t2s_special_token_ids, + text_vocab_size_override=len(uni_prompting.text_tokenizer), + ) + + if batch_size_t2i == 0: + loss_t2i = loss_t2i.new_zeros(()) + if batch_size_i2i == 0: + loss_i2i = loss_i2i.new_zeros(()) + + # Gather the losses across all processes for logging (use reduce to avoid shape mismatches) + avg_loss_t2i = accelerator.reduce(loss_t2i, reduction='mean') + avg_loss_i2i = accelerator.reduce(loss_i2i, reduction='mean') + avg_loss_lm = accelerator.reduce(loss_lm, reduction='mean') + avg_loss_mmu = accelerator.reduce(loss_mmu, reduction='mean') + avg_loss_vid = accelerator.reduce(loss_vid, reduction='mean') + avg_loss_v2s = accelerator.reduce(loss_v2s, reduction='mean') + avg_loss_s2t = accelerator.reduce(loss_s2t, reduction='mean') + avg_loss_s2s = accelerator.reduce(loss_s2s, reduction='mean') + if not torch.isfinite(loss_t2s): + if labels_t2s.numel() > 0: + speech_vocab_end = speech_vocab_start + audio_codebook_size + valid_mask = labels_t2s != -100 + if valid_mask.any(): + labels_valid = labels_t2s[valid_mask] + below_count = (labels_valid < speech_vocab_start).sum().item() + above_count = (labels_valid >= speech_vocab_end).sum().item() + labels_min = labels_valid.min().item() + labels_max = labels_valid.max().item() + else: + below_count = above_count = 0 + labels_min = labels_max = -100 + p_mask_min = p_mask_t2s.min().item() if p_mask_t2s.numel() > 0 else float("nan") + ans_len_min = ( + answer_lengths_t2s.min().item() + if answer_lengths_t2s.numel() > 0 + else float("nan") + ) + accelerator.print( + "[t2s NaN debug] " + f"rank={accelerator.process_index} step={global_step} " + f"slice=({speech_vocab_start}, {speech_vocab_end}) " + f"labels_min={labels_min} labels_max={labels_max} " + f"below_slice={below_count} above_slice={above_count} " + f"p_mask_min={p_mask_min} answer_len_min={ans_len_min}" + ) + accelerator.print( + f"[rank {accelerator.process_index}] t2s loss became NaN/Inf at global step {global_step} " + f"(local value: {loss_t2s.item()})" + ) + logger.warning( + "[rank %s] t2s loss became NaN/Inf at global step %s (local value: %s)", + accelerator.process_index, + global_step, + loss_t2s.item(), + ) + avg_loss_t2s = accelerator.reduce(loss_t2s, reduction='mean') + if not torch.isfinite(avg_loss_t2s): + accelerator.print( + f"[rank {accelerator.process_index}] reduced t2s loss NaN/Inf at global step {global_step} " + f"(value after all-reduce: {avg_loss_t2s.item()})" + ) + if accelerator.is_main_process: + logger.warning( + "Reduced t2s loss became NaN/Inf at global step %s (value after all-reduce: %s)", + global_step, + avg_loss_t2s.item(), + ) + + mmu_coeff = getattr(config.training, "mmu_coeff", 0.0) + i2i_coeff = getattr(config.training, "i2i_coeff", config.training.t2i_coeff) + s2s_coeff = getattr(config.training, "s2s_coeff", config.training.t2s_coeff) + v2s_coeff = getattr(config.training, "v2s_coeff", config.training.t2s_coeff) + loss = ( + config.training.t2i_coeff * loss_t2i + + i2i_coeff * loss_i2i + + config.training.lm_coeff * loss_lm + + mmu_coeff * loss_mmu + + config.training.v2t_coeff * loss_vid + + v2s_coeff * loss_v2s + + config.training.s2t_coeff * loss_s2t + + s2s_coeff * loss_s2s + + config.training.t2s_coeff * loss_t2s + ) + + if batch_size_t2i > 0: + local_masking_rate = mask_prob.float().mean() + else: + local_masking_rate = torch.tensor(0.0, device=accelerator.device) + avg_masking_rate = accelerator.reduce(local_masking_rate, reduction='mean') + + if batch_size_i2i > 0: + local_masking_rate_i2i = mask_prob_i2i.float().mean() + else: + local_masking_rate_i2i = torch.tensor(0.0, device=accelerator.device) + avg_masking_rate_i2i = accelerator.reduce(local_masking_rate_i2i, reduction='mean') + + if batch_size_s2s > 0 and p_mask_s2s.numel() > 0: + local_masking_rate_s2s = p_mask_s2s.float().mean() + else: + local_masking_rate_s2s = torch.tensor(0.0, device=accelerator.device) + avg_masking_rate_s2s = accelerator.reduce(local_masking_rate_s2s, reduction='mean') + + if batch_size_v2s > 0 and p_mask_v2s.numel() > 0: + local_masking_rate_v2s = p_mask_v2s.float().mean() + else: + local_masking_rate_v2s = torch.tensor(0.0, device=accelerator.device) + avg_masking_rate_v2s = accelerator.reduce(local_masking_rate_v2s, reduction='mean') + + accelerator.backward(loss) + + if config.training.max_grad_norm is not None and accelerator.sync_gradients: + accelerator.clip_grad_norm_(model.parameters(), config.training.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + + # log gradient norm before zeroing it + if ( + accelerator.sync_gradients + and (global_step + 1) % config.experiment.log_grad_norm_every == 0 + and accelerator.is_main_process + ): + log_grad_norm(model, accelerator, global_step + 1) + + optimizer.zero_grad(set_to_none=True) + + if accelerator.sync_gradients: + batch_time_m.update(time.time() - end) + end = time.time() + + # Log metrics + if (global_step + 1) % config.experiment.log_every == 0: + samples_per_second_per_gpu = ( + config.training.gradient_accumulation_steps * total_batch_size_per_gpu / batch_time_m.val + ) + logs = { + "lr": lr_scheduler.get_last_lr()[0], + "avg_masking_rate": avg_masking_rate.item(), + "avg_masking_rate_i2i": avg_masking_rate_i2i.item(), + "avg_masking_rate_v2s": avg_masking_rate_v2s.item(), + "avg_masking_rate_s2s": avg_masking_rate_s2s.item(), + "samples/sec/gpu": samples_per_second_per_gpu, + "data_time": data_time_m.val, + "batch_time": batch_time_m.val, + } + + loss_entries = [ + ("step_loss_t2i", avg_loss_t2i), + ("step_loss_i2i", avg_loss_i2i), + ("step_loss_lm", avg_loss_lm), + ("step_loss_mmu", avg_loss_mmu), + ("step_loss_vid", avg_loss_vid), + ("step_loss_v2s", avg_loss_v2s), + ("step_loss_s2t", avg_loss_s2t), + ("step_loss_s2s", avg_loss_s2s), + ("step_loss_t2s", avg_loss_t2s), + ] + + loss_log_parts = [] + for key, value in loss_entries: + loss_value = value.item() + if loss_value != 0.0: + logs[key] = loss_value + loss_log_parts.append(f"{key.replace('step_', '').capitalize()}: {loss_value:0.4f}") + + accelerator.log(logs, step=global_step + 1) + + loss_str = " ".join(loss_log_parts) + logger.info( + "Step: %d %s Data (t): %.4f, %.2f/s/gpu Batch (t): %.4f LR: %.6f" + % ( + global_step + 1, + loss_str, + data_time_m.val, + samples_per_second_per_gpu, + batch_time_m.val, + lr_scheduler.get_last_lr()[0], + ) + ) + + # resetting batch / data time meters per log window + batch_time_m.reset() + data_time_m.reset() + + # Save model checkpoint + if (global_step + 1) % config.experiment.save_every == 0: + save_checkpoint(model, config, accelerator, global_step + 1, uni_prompting) + + # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + # ++++++++++++++++++++++ RUN EVALUATION +++++++++++++++++++++++++ + # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + # if global_step == 0 or (global_step + 1) % config.experiment.get("eval_every", 5000) == 0: + if (global_step + 1) % config.experiment.get("eval_every", 5000) == 0: + run_evaluation( + model=accelerator.unwrap_model(model), + vq_model_image=vq_model_image, + vq_model_audio=vq_model_audio, + uni_prompting=uni_prompting, + config=config, + accelerator=accelerator, + global_step=global_step + 1 + ) + # Evaluation function sets model back to train mode internally + # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + + global_step += 1 + + if global_step >= config.training.max_train_steps: + break + + if global_step >= config.training.max_train_steps: + break + + accelerator.wait_for_everyone() + + save_checkpoint(model, config, accelerator, global_step, uni_prompting) + + if accelerator.is_main_process: + model = accelerator.unwrap_model(model) + model.save_pretrained(config.experiment.output_dir, safe_serialization=True) + + accelerator.end_training() + +@torch.no_grad() +def visualize_predictions(*args, **kwargs): + # This function is not called in the main loop but kept for compatibility + pass + +@torch.no_grad() +def generate_images(*args, **kwargs): + # This function is not called in the main loop but kept for compatibility + pass + +@torch.no_grad() +def understanding_images(*args, **kwargs): + # This function is not called in the main loop but kept for compatibility + pass + +def save_checkpoint(model, config, accelerator, global_step, uni_prompting): + output_dir = config.experiment.output_dir + checkpoints_total_limit = config.experiment.get("checkpoints_total_limit", None) + + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if accelerator.is_main_process and checkpoints_total_limit is not None: + checkpoints = os.listdir(output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= checkpoints_total_limit: + num_to_remove = len(checkpoints) - checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = Path(output_dir) / f"checkpoint-{global_step}" + + # retrieve the model on all processes for deepspeed stage 3 to work then save on one process (we are not using stage 3 yet) + # XXX: could also make this conditional on deepspeed + state_dict = accelerator.get_state_dict(model) + if accelerator.is_main_process: + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.save_pretrained( + save_path / "unwrapped_model", + save_function=accelerator.save, + state_dict=state_dict, + safe_serialization=True + ) + json.dump({"global_step": global_step}, (save_path / "metadata.json").open("w+")) + logger.info(f"Saved state to {save_path}") + + # save tokenizer + uni_prompting.text_tokenizer.save_pretrained(save_path/ "unwrapped_model") + + +def log_grad_norm(model, accelerator, global_step): + for name, param in model.named_parameters(): + if param.grad is not None: + grads = param.grad.detach().data + grad_norm = (grads.norm(p=2) / grads.numel()).item() + accelerator.log({"grad_norm/" + name: grad_norm}, step=global_step) + + +if __name__ == "__main__": + main() diff --git a/MMaDA/training/train_omada_inst_acc.py b/MMaDA/training/train_omada_inst_acc.py new file mode 100644 index 0000000000000000000000000000000000000000..2235adf986610d01cef44f885baf54785f6a1601 --- /dev/null +++ b/MMaDA/training/train_omada_inst_acc.py @@ -0,0 +1,4472 @@ +# Copyright 2025 AIDAS Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import base64 +import binascii +import hashlib +import os +import sys +import warnings +import subprocess +import tempfile +os.environ["FFMPEG_LOG_LEVEL"] = "error" +warnings.filterwarnings("ignore") + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +os.environ["TOKENIZERS_PARALLELISM"] = "true" +import json +import logging +import math +import torch.nn.functional as F +import shutil +import time +import cv2 +import glob +import random +import contextlib +from tqdm import tqdm +from pathlib import Path +from typing import Optional, Union, Dict, Any, List, Iterator, NamedTuple +from collections.abc import Sequence +import csv +import numpy as np +from PIL import Image +from io import BytesIO +from omegaconf import OmegaConf, DictConfig +import wandb +import torch +from torch.optim import AdamW +import torch.multiprocessing as mp +from dataclasses import dataclass, field + +try: + cv2.utils.logging.setLogLevel(cv2.utils.logging.LOG_LEVEL_ERROR) +except AttributeError: + warnings.filterwarnings("ignore", category=FutureWarning) + +from transformers import AutoTokenizer, AutoConfig +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import DistributedType, set_seed +# +++++ I2I-specific Imports +++++ +from datasets import load_dataset +from torch.utils.data import Dataset, DataLoader +from tqdm.auto import tqdm +# ++++++++++++++++++++++++++++++ + +# +++++ Omni-modal-specific Imports +++++ +from models.modeling_emova_speech_tokenizer import EMOVASpeechTokenizer +from datasets import load_dataset +from torch.utils.data import Dataset, DataLoader, DistributedSampler +from tqdm.auto import tqdm +from training.data import ( + SpeechTextDataset, + MixedSpeechTextDataset, + Speech2SpeechDataset, + TextImageInterleavedDataset, + load_video_mp4, + VideoCaptionDataset, + VideoSpeechDataset, + S2T_INSTRUCTION, + T2S_INSTRUCTION, + V2S_INSTRUCTION, + s2s_collate_fn, +) +# import librosa + +from training.data import ( + Text2ImageDataset, + HQEditX2IDataset, + CombinedX2IDataset, + HFInstructionTextDataset, + TextToImage2MDataset, + OpenImageI2IDataset, +) +from training.utils import get_config, flatten_omega_conf, image_transform +from training.imagenet_dataset import ImageNetDataset + +from models import MAGVITv2, get_mask_schedule, OMadaModelLM, OMadaConfig +from training.prompting_utils import UniversalPrompting +from models.lr_schedulers import get_scheduler +from models.logging import set_verbosity_info, set_verbosity_error + +from torch.utils.data import DataLoader, Dataset +from torch.utils.data.distributed import DistributedSampler + +# ++++++++ EVALUATION IMPORTS ++++++++ +import re +import editdistance +import soundfile as sf +from functools import partial +from transformers import pipeline +# ++++++++++++++++++++++++++++++++++++ + +SYSTEM_PROMPT_LEN = 28 + +cv2.setNumThreads(0) +torch.set_num_threads(1) +os.environ["OMP_NUM_THREADS"] = "1" +os.environ["MKL_NUM_THREADS"] = "1" +os.environ["OPENBLAS_NUM_THREADS"] = "1" +os.environ["NUMEXPR_NUM_THREADS"] = "1" +os.environ["NCCL_ASYNC_ERROR_HANDLING"]= "1" +os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL" +os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"]= "1" +os.environ["TORCH_NCCL_BLOCKING_WAIT"]= "1" + +from training.utils import get_config, flatten_omega_conf, mask_or_random_replace_tokens, AverageMeter + +try: + import apex + +is_apex_available = True +except ImportError: + is_apex_available = False + +logger = get_logger(__name__, log_level="INFO") + + +@dataclass +class TaskBatchPayload: + """Container for task-specific batch data prior to global padding/aggregation.""" + + name: str + input_ids: torch.Tensor + labels: torch.Tensor + batch_size: int + attention_mask: Optional[torch.Tensor] = None + p_mask: Optional[torch.Tensor] = None + answer_lengths: Optional[torch.Tensor] = None + weight: float = 1.0 + metadata: Dict[str, Any] = field(default_factory=dict) + + def to_device(self, device: torch.device) -> "TaskBatchPayload": + """Ensure tensors are on the target device.""" + self.input_ids = self.input_ids.to(device, non_blocking=True) + self.labels = self.labels.to(device, non_blocking=True) + if self.attention_mask is not None: + self.attention_mask = self.attention_mask.to(device, non_blocking=True) + if self.p_mask is not None: + self.p_mask = self.p_mask.to(device, non_blocking=True) + if self.answer_lengths is not None: + self.answer_lengths = self.answer_lengths.to(device, non_blocking=True) + for key, value in list(self.metadata.items()): + if torch.is_tensor(value): + self.metadata[key] = value.to(device, non_blocking=True) + return self + + +def _make_payload( + name: str, + input_ids: torch.Tensor, + labels: torch.Tensor, + batch_size: int, + *, + attention_mask: Optional[torch.Tensor] = None, + p_mask: Optional[torch.Tensor] = None, + answer_lengths: Optional[torch.Tensor] = None, + weight: float = 1.0, + metadata: Optional[Dict[str, Any]] = None, +) -> TaskBatchPayload: + return TaskBatchPayload( + name=name, + input_ids=input_ids, + labels=labels, + batch_size=batch_size, + attention_mask=attention_mask, + p_mask=p_mask, + answer_lengths=answer_lengths, + weight=weight, + metadata=metadata or {}, + ) + + +def _broadcast_choice(value: int, accelerator: Accelerator, src: int = 0) -> int: + """Synchronize an integer choice across all ranks without relying on NCCL shared memory.""" + if accelerator.num_processes == 1: + return value + choice_tensor = torch.tensor( + [value if accelerator.process_index == src else 0], + device=accelerator.device, + dtype=torch.int32, + ) + gathered = accelerator.gather(choice_tensor) + return int(gathered[src].item()) + + +def _patch_shared_memory_tracker() -> None: + """Stop Python's resource_tracker from eagerly unlinking shared memory segments.""" + try: + from multiprocessing import resource_tracker + except ImportError: + return + + if getattr(resource_tracker, "_omada_shared_memory_patch", False): + return + + orig_register = resource_tracker.register + orig_unregister = resource_tracker.unregister + cleanup_funcs = getattr(resource_tracker, "_CLEANUP_FUNCS", {}) + orig_cleanup = cleanup_funcs.get("shared_memory") + + def _safe_register(name: str, rtype: str) -> None: + if rtype == "shared_memory": + return + orig_register(name, rtype) + + def _safe_unregister(name: str, rtype: str) -> None: + if rtype == "shared_memory": + return + orig_unregister(name, rtype) + + def _noop(*_args, **_kwargs) -> None: + return + + resource_tracker.register = _safe_register # type: ignore[assignment] + resource_tracker.unregister = _safe_unregister # type: ignore[assignment] + if orig_cleanup is not None: + cleanup_funcs["shared_memory"] = _noop + resource_tracker._omada_shared_memory_patch = True # type: ignore[attr-defined] + + +def _configure_multiprocessing() -> None: + """Configure torch multiprocessing to avoid shared-memory exhaustion on multi-worker loads.""" + try: + _patch_shared_memory_tracker() + except Exception as exc: # pragma: no cover - best-effort patching + logger.warning("Failed to apply shared memory tracker patch: %s", exc) + + try: + mp.set_sharing_strategy("file_descriptor") + except RuntimeError as exc: + logger.warning("Failed to set multiprocessing sharing strategy to 'file_descriptor': %s", exc) + +def pad_tensor(tensor, length, value): + pad_size = length - tensor.shape[1] + if pad_size <= 0: + return tensor + # Pad on the right side of the sequence (last dimension) + return torch.nn.functional.pad(tensor, (0, pad_size), "constant", value) + +def pad_answer_lengths(ans: torch.Tensor, length: int) -> torch.Tensor: + b, l = ans.shape + if l >= length: + return ans + pad_block = ans[:, :1].expand(b, length - l) + return torch.cat([ans, pad_block], dim=1) + +def resize_vocab(model, config): + logger.info(f"Resizing token embeddings to {config.model.omada.new_vocab_size}") + model.resize_token_embeddings(config.model.omada.new_vocab_size) + +def get_vq_model_class(model_type): + if model_type == "magvitv2": + return MAGVITv2 + elif model_type == "emova": + return EMOVASpeechTokenizer.from_pretrained( + "Emova-ollm/emova_speech_tokenizer_hf" + ) + else: + raise ValueError(f"model_type {model_type} not supported.") + +def collate_fn_audio(batch): + # In this setup, the tokenizer handles batching of audio paths + return { + 'audio_path': [item['audio_path'] for item in batch], + 'text': [item['text'] for item in batch], + 'audio_tokens': [item.get('audio_tokens') for item in batch], + } + + +def _empty_audio_batch() -> dict[str, list[Any]]: + """Utility to create an empty speech batch placeholder.""" + return { + "audio_path": [], + "text": [], + "audio_tokens": [], + } + + +def collate_fn_mmu_mult(batch): + return { + 'images': [item['images'] for item in batch], + 'text': [item['text'] for item in batch], + } + + +def collate_fn_x2i(batch): + t2i_texts: list[str] = [] + t2i_images: list[torch.Tensor] = [] + + i2i_prompts: list[str] = [] + i2i_source_images: list[torch.Tensor] = [] + i2i_target_images: list[torch.Tensor] = [] + + ref_image: Optional[torch.Tensor] = None + + has_i2i_sample = False + + for sample in batch: + input_prompt = sample.get("input_prompt") + output_prompt = sample.get("output_prompt") + edit_prompt = sample.get("edit_prompt") + inverse_prompt = sample.get("inverse_prompt") + input_image = sample.get("input_image") + output_image = sample.get("output_image") + + if isinstance(input_image, torch.Tensor) and ref_image is None: + ref_image = input_image + if isinstance(output_image, torch.Tensor) and ref_image is None: + ref_image = output_image + + has_edit_pair = ( + isinstance(input_image, torch.Tensor) + and isinstance(output_image, torch.Tensor) + and ( + (edit_prompt and edit_prompt.strip()) + or (inverse_prompt and inverse_prompt.strip()) + ) + ) + + if has_edit_pair: + has_i2i_sample = True + edit_candidates: list[tuple[str, torch.Tensor, torch.Tensor]] = [] + if edit_prompt and edit_prompt.strip(): + edit_candidates.append((edit_prompt, input_image, output_image)) + if inverse_prompt and inverse_prompt.strip(): + edit_candidates.append((inverse_prompt, output_image, input_image)) + + if edit_candidates: + chosen_prompt, chosen_src, chosen_tgt = random.choice(edit_candidates) + i2i_prompts.append(chosen_prompt) + i2i_source_images.append(chosen_src) + i2i_target_images.append(chosen_tgt) + continue + else: + if input_prompt and isinstance(input_image, torch.Tensor): + t2i_texts.append(input_prompt) + t2i_images.append(input_image) + elif output_prompt and isinstance(output_image, torch.Tensor): + t2i_texts.append(output_prompt) + t2i_images.append(output_image) + + if has_i2i_sample: + # i2iź°€ ķ•˜ė‚˜ė¼ė„ ģžˆģœ¼ė©“ ģ“ė²ˆ ė°°ģ¹˜ėŠ” i2i ģ „ģš©ģœ¼ė”œ ģ‚¬ģš©ķ•˜ź³  t2iėŠ” 비움 + t2i_texts = [] + t2i_images = [] + + def stack_images(images: list[torch.Tensor]) -> torch.Tensor: + if images: + return torch.stack(images, dim=0) + if ref_image is not None: + c, h, w = ref_image.shape[-3:] + return torch.empty((0, c, h, w), dtype=ref_image.dtype) + return torch.empty((0, 3, 0, 0), dtype=torch.float32) + + return { + "t2i": { + "texts": t2i_texts, + "images": stack_images(t2i_images), + }, + "i2i": { + "prompts": i2i_prompts, + "source_images": stack_images(i2i_source_images), + "target_images": stack_images(i2i_target_images), + }, + } + + +def collate_fn_v2t(batch: list[dict[str, Any]]) -> Optional[dict[str, Any]]: + filtered = [sample for sample in batch if sample is not None] + if not filtered: + return None + video_tensors: list[torch.Tensor] = [] + captions: list[Any] = [] + for sample in filtered: + frames = sample.get("video") + if frames is None: + continue + frame_tensor = torch.stack(frames, dim=0) + video_tensors.append(frame_tensor) + captions.append(sample.get("caption")) + if not video_tensors: + return None + return { + "video": torch.stack(video_tensors, dim=0), + "captions": captions, + } + + +def collate_fn_v2s(batch: list[dict[str, Any]]) -> Optional[dict[str, Any]]: + filtered = [sample for sample in batch if sample is not None] + if not filtered: + return None + video_tensors: list[torch.Tensor] = [] + speech_entries: list[Any] = [] + for sample in filtered: + frames = sample.get("video") + speech_value = sample.get("speech") + if frames is None or speech_value is None: + continue + frame_tensor = torch.stack(frames, dim=0) + video_tensors.append(frame_tensor) + speech_entries.append(speech_value) + if not video_tensors: + return None + return { + "video": torch.stack(video_tensors, dim=0), + "speech": speech_entries, + } + +def collate_fn_video_multimodal(batch): + text_videos: list[torch.Tensor] = [] + text_captions: list = [] + speech_videos: list[torch.Tensor] = [] + speech_entries: list[Any] = [] + + for sample in batch: + if sample is None: + continue + text_sample = sample.get("text") + if isinstance(text_sample, dict) and text_sample.get("video") is not None: + frames = text_sample["video"] + frame_tensor = torch.stack(frames, dim=0) + text_videos.append(frame_tensor) + text_captions.append(text_sample["caption"]) + + speech_sample = sample.get("speech") + if isinstance(speech_sample, dict) and speech_sample.get("video") is not None: + frames = speech_sample["video"] + frame_tensor = torch.stack(frames, dim=0) + speech_videos.append(frame_tensor) + speech_entries.append(speech_sample["speech"]) + + output: Dict[str, Any] = {} + if text_videos: + output["text"] = { + "video": torch.stack(text_videos, dim=0), + "captions": text_captions, + } + if speech_videos: + output["speech"] = { + "video": torch.stack(speech_videos, dim=0), + "speech": speech_entries, + } + if not output: + return None + return output + + +def s2t_eval_collate_fn(batch, vq_model_audio, tokenizer, uni_prompting, config): + + audio_tokens_batch = [] + offset = len(uni_prompting.text_tokenizer) + int(config.model.omada.codebook_size) + for item in batch: + audio_entry = item['audio_path'] + if isinstance(audio_entry, torch.Tensor): + tokens = audio_entry.cpu() + else: + tokens = vq_model_audio.encode(audio_entry).cpu() + tokens_with_offset = tokens + offset + audio_tokens_batch.append(tokens_with_offset) + + sptids_dict = uni_prompting.sptids_dict + device = audio_tokens_batch[0].device + batched_input_ids = [] + + for audio_tokens in audio_tokens_batch: + task_tensor = sptids_dict['<|s2t|>'].to(device).unsqueeze(0) + soa_tensor = sptids_dict['<|soa|>'].to(device).unsqueeze(0) + eoa_tensor = sptids_dict['<|eoa|>'].to(device).unsqueeze(0) + audio_block = torch.cat([task_tensor, soa_tensor, audio_tokens, eoa_tensor], dim=1) + + prompt_text = random.choice(S2T_INSTRUCTION) + full_prompt_text = f'<|start_header_id|>user<|end_header_id|>\n{prompt_text}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n' + prompt_tensor = tokenizer(full_prompt_text, return_tensors="pt").input_ids.to(device) + + final_sequence = torch.cat([audio_block, prompt_tensor], dim=1) + batched_input_ids.append(final_sequence.squeeze(0)) + + max_len = max(seq.size(0) for seq in batched_input_ids) + pad_token_id = 126093 + + final_batch_input_ids = torch.full( + (len(batched_input_ids), max_len), + pad_token_id, + dtype=torch.long, + device=device + ) + + for i, seq in enumerate(batched_input_ids): + final_batch_input_ids[i, -len(seq):] = seq + + return { + "input_ids": final_batch_input_ids, + "gt_texts": [item['gt_text'] for item in batch], + "sample_ids": [item['sample_id'] for item in batch] + } + +################################################################################################ +# +++++++++++++++++++++++++++++++++++++ EVALUATION HELPERS +++++++++++++++++++++++++++++++++++++ +################################################################################################ + +def add_gumbel_noise(logits, temperature): + ''' + The Gumbel max is a method for sampling categorical distributions. + According to arXiv:2409.02908, for MDM, low-precision Gumbel Max improves perplexity score but reduces generation quality. + Thus, we use float64. + ''' + if temperature == 0: + return logits + logits = logits.to(torch.float64) + noise = torch.rand_like(logits, dtype=torch.float64) + gumbel_noise = (- torch.log(noise)) ** temperature + return logits.exp() / gumbel_noise + + +def get_num_transfer_tokens(mask_index, steps): + ''' + In the reverse process, the interval [0, 1] is uniformly discretized into steps intervals. + Furthermore, because LLaDA employs a linear noise schedule (as defined in Eq. (8)), + the expected number of tokens transitioned at each step should be consistent. + + This function is designed to precompute the number of tokens that need to be transitioned at each step. + ''' + mask_num = mask_index.sum(dim=1, keepdim=True) + + base = mask_num // steps + remainder = mask_num % steps + + num_transfer_tokens = torch.zeros(mask_num.size(0), steps, device=mask_index.device, dtype=torch.int64) + base + + for i in range(mask_num.size(0)): + num_transfer_tokens[i, :remainder[i]] += 1 + + return num_transfer_tokens + +@ torch.no_grad() +def generate(model, prompt, steps=128, gen_length=128, block_length=128, temperature=0., + cfg_scale=0., remasking='low_confidence', mask_id=126336, attention_mask=None): + ''' + Args: + model: Mask predictor. + prompt: A tensor of shape (B, L), where B is batch size. + steps: Sampling steps, less than or equal to gen_length. + gen_length: Generated answer length. + block_length: Block length, less than or equal to gen_length. If less than gen_length, it means using semi_autoregressive remasking. + temperature: Categorical distribution sampling temperature. + cfg_scale: Unsupervised classifier-free guidance scale. + remasking: Remasking strategy. 'low_confidence' or 'random'. + mask_id: The toke id of [MASK] is 126336. + ''' + if attention_mask is not None and 0.0 in attention_mask: + attention_bias = (attention_mask[:, :, None] & attention_mask[:, None, :]).bool().unsqueeze(1) + print(f"attention_bias: {attention_bias}") + else: + attention_bias = None + batch_size = prompt.shape[0] + x = torch.full((batch_size, prompt.shape[1] + gen_length), mask_id, dtype=torch.long).to(model.device) + x[:, :prompt.shape[1]] = prompt.clone() + + prompt_index = (x != mask_id) + + assert gen_length % block_length == 0 + num_blocks = gen_length // block_length + + assert steps % num_blocks == 0 + steps = steps // num_blocks + + for num_block in range(num_blocks): + block_mask_index = (x[:, prompt.shape[1] + num_block * block_length: prompt.shape[1] + (num_block + 1) * block_length:] == mask_id) + num_transfer_tokens = get_num_transfer_tokens(block_mask_index, steps) + for i in range(steps): + mask_index = (x == mask_id) + if cfg_scale > 0.: + un_x = x.clone() + un_x[prompt_index] = mask_id + x_ = torch.cat([x, un_x], dim=0) + logits = model(x_).logits + logits, un_logits = torch.chunk(logits, 2, dim=0) + logits = un_logits + (cfg_scale + 1) * (logits - un_logits) + else: + logits = model(x, attention_bias=attention_bias).logits + + logits_with_noise = add_gumbel_noise(logits, temperature=temperature) + x0 = torch.argmax(logits_with_noise, dim=-1) # b, l + + if remasking == 'low_confidence': + p = F.softmax(logits.to(torch.float64), dim=-1) + x0_p = torch.squeeze( + torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1) # b, l + elif remasking == 'random': + x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device) + else: + raise NotImplementedError(remasking) + + x0_p[:, prompt.shape[1] + (num_block + 1) * block_length:] = -np.inf + + x0 = torch.where(mask_index, x0, x) + confidence = torch.where(mask_index, x0_p, -np.inf) + # print(confidence.shape) + transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device) + for j in range(confidence.shape[0]): + _, select_index = torch.topk(confidence[j], k=num_transfer_tokens[j, i]) + transfer_index[j, select_index] = True + x[transfer_index] = x0[transfer_index] + + return x + +def normalize_text(text): + """A simple normalizer for WER calculation.""" + text = text.lower() + text = re.sub(r"[^\w\s']", "", text) + return text + +def calculate_wer(predictions, references): + """Calculates the Word Error Rate (WER) between predicted and ground truth texts.""" + predictions = [normalize_text(p) for p in predictions] + references = [normalize_text(r) for r in references] + + total_errors = 0 + total_words = 0 + for pred, ref in zip(predictions, references): + pred_words = pred.split() + ref_words = ref.split() + total_errors += editdistance.eval(pred_words, ref_words) + total_words += len(ref_words) + + wer = total_errors / total_words if total_words > 0 else 0.0 + return wer, total_errors, total_words + +class S2TEvalDataset(Dataset): + def __init__(self, hf_dataset, root_path): + self.hf_dataset = hf_dataset + self.root_path = root_path + + def __len__(self): + return len(self.hf_dataset) + + def __getitem__(self, idx): + example = self.hf_dataset[idx] + sample_id = example['id'] + speaker_id, chapter_id, _ = sample_id.split('-') + audio_path = os.path.join(self.root_path, speaker_id, chapter_id, f"{sample_id}.flac") + + return { + "audio_path": audio_path, + "gt_text": example["text"], + "sample_id": sample_id + } + +# --- T2S Evaluation Dataset --- +class T2SEvalDataset(Dataset): + def __init__(self, hf_dataset): + self.hf_dataset = hf_dataset + def __len__(self): + return len(self.hf_dataset) + def __getitem__(self, idx): + example = self.hf_dataset[idx] + return {"gt_text": example['text'], "sample_id": example['id']} + +def _resolve_mask_schedule(config): + schedule_cfg = getattr(config, "mask_schedule", None) + if isinstance(schedule_cfg, DictConfig): + schedule_name = getattr(schedule_cfg, "schedule", None) + params_cfg = getattr(schedule_cfg, "params", None) + elif isinstance(schedule_cfg, dict): + schedule_name = schedule_cfg.get("schedule") + params_cfg = schedule_cfg.get("params") + else: + schedule_name = None + params_cfg = None + if schedule_name is None: + schedule_name = config.training.get("mask_schedule", "cosine") + params = {} + if params_cfg is not None: + if isinstance(params_cfg, DictConfig): + params = OmegaConf.to_container(params_cfg, resolve=True) or {} + elif isinstance(params_cfg, dict): + params = dict(params_cfg) + else: + params = params_cfg + if not isinstance(params, dict): + params = {} + return get_mask_schedule(schedule_name, **params) +def _tensor_to_pil(image_tensor: torch.Tensor) -> Image.Image: + image = torch.clamp((image_tensor.detach().cpu().float() + 1.0) / 2.0, min=0.0, max=1.0) + array = (image.permute(1, 2, 0).numpy() * 255.0).astype(np.uint8) + return Image.fromarray(array) + +################################################################################################ +# +++++++++++++++++++++++++++++++++++++ T2I EVALUATION LOGIC +++++++++++++++++++++++++++++++++++++ +################################################################################################ + +@torch.no_grad() +def evaluate_t2i(model, vq_model_image, uni_prompting, config, accelerator, global_step): + if not accelerator.is_main_process: + return + logger.info("***** Running T2I Evaluation *****") + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.eval() + prompts_file = "/home/work/AIDAS/MMaDA/validation_prompts/quantative.txt" + if not prompts_file: + logger.warning("No validation prompts file configured. Skipping T2I evaluation.") + return + prompts_path = Path(prompts_file) + if not prompts_path.is_absolute(): + prompts_path = Path.cwd() / prompts_path + if not prompts_path.exists(): + repo_root = Path(__file__).resolve().parents[2] + alt_path = repo_root / prompts_file + if alt_path.exists(): + prompts_path = alt_path + try: + with open(prompts_path, "r", encoding="utf-8") as handle: + prompts = [line.strip() for line in handle if line.strip()] + except OSError as exc: + logger.warning(f"Failed to read validation prompts from '{prompts_path}': {exc}. Skipping T2I evaluation.") + return + if not prompts: + logger.warning("Validation prompts file is empty. Skipping T2I evaluation.") + return + max_samples = getattr(config.experiment, "eval_num_t2i_samples", 8) + if not isinstance(max_samples, int) or max_samples <= 0: + max_samples = 8 + prompts = prompts[:max_samples] + mask_schedule = _resolve_mask_schedule(config) + mask_token_id = unwrapped_model.config.mask_token_id + seq_len = getattr(getattr(config.model, "omada", None), "num_vq_tokens", None) + if seq_len is None: + seq_len = getattr(unwrapped_model.config, "num_vq_tokens", None) + if seq_len is None: + logger.warning("Unable to determine image token sequence length. Skipping T2I evaluation.") + return + seq_len = int(seq_len) + device = accelerator.device + image_tokens = torch.full((len(prompts), seq_len), mask_token_id, dtype=torch.long, device=device) + input_ids, attention_mask = uni_prompting((prompts, image_tokens), 't2i_gen') + if config.training.guidance_scale > 0: + uncond_input_ids, uncond_attention_mask = uni_prompting(([''] * len(prompts), image_tokens), 't2i_gen') + cfg_scale = config.training.guidance_scale + else: + uncond_input_ids, uncond_attention_mask = None, None + cfg_scale = 0.0 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + else: + weight_dtype = torch.float32 + use_autocast = accelerator.device.type == "cuda" and accelerator.mixed_precision != "no" + autocast_ctx = torch.autocast("cuda", dtype=weight_dtype) if use_autocast else contextlib.nullcontext() + with autocast_ctx: + gen_token_ids = unwrapped_model.t2i_generate( + input_ids=input_ids, + uncond_input_ids=uncond_input_ids, + attention_mask=attention_mask, + uncond_attention_mask=uncond_attention_mask, + guidance_scale=3.5, + temperature=config.training.get("generation_temperature", 1.0), + timesteps=15, + noise_schedule=mask_schedule, + noise_type=config.training.get("noise_type", "mask"), + predict_all_tokens=config.training.get("predict_all_tokens", False), + seq_len=seq_len, + uni_prompting=uni_prompting, + config=config, + ) + gen_token_ids = torch.clamp(gen_token_ids, min=0, max=unwrapped_model.config.codebook_size - 1) + images = vq_model_image.decode_code(gen_token_ids) + images = torch.clamp((images + 1.0) / 2.0, min=0.0, max=1.0) + images = images.permute(0, 2, 3, 1).cpu().numpy() * 255.0 + pil_images = [Image.fromarray(img.astype(np.uint8)) for img in images] + wandb_images = [wandb.Image(img, caption=prompt) for img, prompt in zip(pil_images, prompts)] + accelerator.log({"eval/t2i_samples": wandb_images}, step=global_step) + +################################################################################################ +# +++++++++++++++++++++++++++++++++++++ I2I EVALUATION LOGIC +++++++++++++++++++++++++++++++++++++ +################################################################################################ + +@torch.no_grad() +def evaluate_i2i(model, vq_model_image, uni_prompting, config, accelerator, global_step): + if not accelerator.is_main_process: + return + logger.info("***** Running I2I Evaluation *****") + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.eval() + dataset_cfg = getattr(config.dataset, "params", {}) + resolution = getattr(dataset_cfg, "resolution", 256) + + def _cfg_to_dict(cfg): + if cfg is None: + return None + if isinstance(cfg, dict): + return cfg + if isinstance(cfg, DictConfig): + return OmegaConf.to_container(cfg, resolve=True) + return cfg + + eval_dataset = None + eval_dataset_name = "unknown" + openimage_cfg_raw = None + if isinstance(dataset_cfg, dict): + openimage_cfg_raw = dataset_cfg.get("openimage_i2i") + else: + openimage_cfg_raw = getattr(dataset_cfg, "openimage_i2i", None) + openimage_cfg = _cfg_to_dict(openimage_cfg_raw) + if openimage_cfg: + try: + eval_dataset = OpenImageI2IDataset( + resolution=resolution, + image_root=openimage_cfg.get("image_root"), + sft_jsonl=openimage_cfg.get("sft_jsonl"), + pref_jsonl=openimage_cfg.get("pref_jsonl"), + multi_turn_jsonl=openimage_cfg.get("multi_turn_jsonl"), + prefer_summarized_text=bool(openimage_cfg.get("prefer_summarized_text", True)), + pref_positive_only=bool(openimage_cfg.get("pref_positive_only", True)), + skip_missing=bool(openimage_cfg.get("skip_missing", True)), + max_samples_per_source=openimage_cfg.get("max_samples_per_source"), + max_total_samples=openimage_cfg.get("max_total_samples"), + seed=openimage_cfg.get("seed"), + ) + eval_dataset_name = "OpenImage I2I" + logger.info("Using OpenImage I2I dataset for evaluation (samples=%d).", len(eval_dataset)) + except Exception as exc: + logger.warning("Failed to build OpenImage I2I eval dataset (%s); falling back to HQ-Edit.", exc) + eval_dataset = None + if eval_dataset is None: + eval_dataset = HQEditX2IDataset(split='train', resolution=resolution) + eval_dataset_name = "HQ-Edit" + logger.info("Using HQ-Edit dataset for I2I evaluation.") + + if len(eval_dataset) == 0: + logger.warning("%s evaluation split is empty. Skipping I2I evaluation.", eval_dataset_name) + return + max_samples = getattr(config.experiment, "eval_num_i2i_samples", 8) + + if not isinstance(max_samples, int) or max_samples <= 0: + max_samples = 8 + num_samples = min(max_samples, len(eval_dataset)) + if len(eval_dataset) <= num_samples: + sample_indices = list(range(len(eval_dataset))) + else: + sample_indices = random.sample(range(len(eval_dataset)), num_samples) + samples = [eval_dataset[i] for i in sample_indices] + prompts = [] + original_tensors = [] + target_tensors = [] + for sample in samples: + prompts.append(sample.get("edit_prompt") or sample.get("output_prompt") or "") + original_tensors.append(sample["input_image"]) + target_tensors.append(sample["output_image"]) + original_images = torch.stack(original_tensors, dim=0).to(accelerator.device) + original_tokens = vq_model_image.get_code(original_images) + len(uni_prompting.text_tokenizer) + seq_len = original_tokens.shape[-1] + mask_token_id = unwrapped_model.config.mask_token_id + placeholder = torch.full((num_samples, seq_len), mask_token_id, dtype=torch.long, device=accelerator.device) + input_ids, attention_mask = uni_prompting((prompts, original_tokens, placeholder), 'i2i_gen') + if config.training.guidance_scale > 0: + uncond_input_ids, uncond_attention_mask = uni_prompting( + ([''] * num_samples, original_tokens, placeholder), 'i2i_gen' + ) + cfg_scale = config.training.guidance_scale + else: + uncond_input_ids, uncond_attention_mask = None, None + cfg_scale = 0.0 + mask_schedule = _resolve_mask_schedule(config) + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + else: + weight_dtype = torch.float32 + use_autocast = accelerator.device.type == "cuda" and accelerator.mixed_precision != "no" + autocast_ctx = torch.autocast("cuda", dtype=weight_dtype) if use_autocast else contextlib.nullcontext() + with autocast_ctx: + gen_token_ids = unwrapped_model.i2i_generate( + input_ids=input_ids, + uncond_input_ids=uncond_input_ids, + attention_mask=attention_mask, + uncond_attention_mask=uncond_attention_mask, + guidance_scale=3.5, + temperature=config.training.get("generation_temperature", 1.0), + timesteps=15, + noise_schedule=mask_schedule, + noise_type=config.training.get("noise_type", "mask"), + seq_len=seq_len, + uni_prompting=uni_prompting, + config=config, + ) + gen_token_ids = torch.clamp(gen_token_ids, min=0, max=unwrapped_model.config.codebook_size - 1) + generated_images = vq_model_image.decode_code(gen_token_ids) + generated_images = torch.clamp((generated_images + 1.0) / 2.0, min=0.0, max=1.0) + gen_images_pil = [Image.fromarray((img.permute(1, 2, 0).cpu().numpy() * 255.0).astype(np.uint8)) for img in generated_images] + source_pil = [_tensor_to_pil(tensor) for tensor in original_tensors] + target_pil = [_tensor_to_pil(tensor) for tensor in target_tensors] + log_resolution = getattr(config.experiment, "eval_image_log_resolution", 512) + wandb_images = [] + for prompt, src, pred, tgt in zip(prompts, source_pil, gen_images_pil, target_pil): + composite = Image.new('RGB', (log_resolution * 3, log_resolution)) + src_resized = src.resize((log_resolution, log_resolution), Image.Resampling.LANCZOS) + pred_resized = pred.resize((log_resolution, log_resolution), Image.Resampling.LANCZOS) + tgt_resized = tgt.resize((log_resolution, log_resolution), Image.Resampling.LANCZOS) + composite.paste(src_resized, (0, 0)) + composite.paste(pred_resized, (log_resolution, 0)) + composite.paste(tgt_resized, (log_resolution * 2, 0)) + wandb_images.append(wandb.Image(composite, caption=f"Prompt: {prompt}")) + accelerator.log({"eval/i2i_samples": wandb_images}, step=global_step) + + +################################################################################################ +# +++++++++++++++++++++++++++++++++++++ S2S EVALUATION LOGIC +++++++++++++++++++++++++++++++++++++ +################################################################################################ +@torch.no_grad() +def evaluate_s2s(model, vq_model_audio, uni_prompting, config, accelerator, global_step): + if not accelerator.is_main_process: + return + + logger.info("***** Running S2S Evaluation *****") + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.eval() + + dataset_cfg = getattr(config.dataset, "params", {}) + if isinstance(dataset_cfg, DictConfig): + dataset_cfg = dataset_cfg + s2s_eval_dir = getattr(dataset_cfg, "s2s_eval_dir", "MMaDA/validation_prompts/s2s") + + base_path = Path(s2s_eval_dir) + if not base_path.is_absolute(): + base_path = Path.cwd() / base_path + if not base_path.exists(): + repo_root = Path(__file__).resolve().parents[2] + alt_path = repo_root / s2s_eval_dir + if alt_path.exists(): + base_path = alt_path + + if not base_path.exists(): + logger.warning(f"S2S evaluation directory '{s2s_eval_dir}' not found. Skipping S2S evaluation.") + return + + audio_exts = {".wav", ".flac", ".mp3", ".ogg", ".m4a"} + wav_files = sorted(p for p in base_path.iterdir() if p.is_file() and p.suffix.lower() in audio_exts) + if not wav_files: + logger.warning(f"No audio files found in '{base_path}'. Skipping S2S evaluation.") + return + + condition = getattr(dataset_cfg, "s2s_eval_condition", "gender-female_emotion-neutral_speed-normal_pitch-normal") + mask_token_id = unwrapped_model.config.mask_token_id + codebook_size = int(getattr(config.model.omada, "codebook_size", 8192)) + + speech_vocab_start = len(uni_prompting.text_tokenizer) + codebook_size + audio_codebook_size = 4096 + if audio_codebook_size <= 0: + logger.warning("Computed audio codebook size is non-positive. Skipping S2S evaluation.") + return + + a_tokens = uni_prompting.text_tokenizer("<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n", return_tensors="pt").input_ids + soa_len = uni_prompting.sptids_dict['<|soa|>'].numel() + eoa_len = uni_prompting.sptids_dict['<|eoa|>'].numel() + asst_header_len = a_tokens.shape[1] + max_audio_len = getattr(uni_prompting, "max_audio_len", 256) + max_generatable = 256 + + offset = len(uni_prompting.text_tokenizer) + codebook_size + device = accelerator.device + + output_root = Path(config.experiment.output_dir) / "eval_s2s" / f"step_{global_step}" + output_root.mkdir(parents=True, exist_ok=True) + + table = wandb.Table(columns=["Audio ID", "Source Audio", "Generated Audio", "Token Count"]) + + for audio_path in wav_files: + try: + user_tokens = vq_model_audio.encode(str(audio_path)) + except Exception as exc: + logger.error(f"Failed to encode '{audio_path}': {exc}") + continue + + if not isinstance(user_tokens, torch.Tensor): + user_tokens = torch.tensor(user_tokens) + if user_tokens.dim() == 1: + user_tokens = user_tokens.unsqueeze(0) + + user_tokens = user_tokens.to(device=device, dtype=torch.long) + if user_tokens.numel() == 0: + logger.warning(f"Encoded audio from '{audio_path}' produced no tokens. Skipping sample.") + continue + + # Use a fixed assistant placeholder length so generation is not limited by user input duration. + assistant_len = max_generatable + if assistant_len <= 0: + logger.warning(f"Assistant placeholder length for '{audio_path}' is non-positive. Skipping sample.") + continue + + user_shifted = user_tokens + offset + assistant_placeholder = torch.full( + (1, assistant_len), + mask_token_id, + dtype=torch.long, + device=device, + ) + + input_ids, attention_mask = uni_prompting( + ([user_shifted], [assistant_placeholder]), + 's2s_gen' + ) + + try: + generated_sequences = unwrapped_model.t2s_generate_mmu_like( + input_ids=input_ids, + max_new_tokens=256, + steps=128, + block_length=128, + temperature=config.training.get("s2s_generation_temperature", 1.0), + cfg_scale=config.training.get("s2s_guidance_scale", 3.0), + mask_token_id=mask_token_id, + attention_mask=attention_mask, + uni_prompting=uni_prompting, + codebook_size=codebook_size, + audio_codebook_size=audio_codebook_size, + ) + except Exception as exc: + logger.error(f"Generation failed for '{audio_path}': {exc}") + continue + + if not generated_sequences: + logger.warning(f"No tokens generated for '{audio_path}'. Skipping sample.") + continue + + gen_tokens = generated_sequences[0] + if isinstance(gen_tokens, torch.Tensor): + gen_tokens = gen_tokens.detach().cpu() + token_list = gen_tokens.tolist() + if not token_list: + logger.warning(f"Generated token list empty for '{audio_path}'. Skipping sample.") + continue + + speech_unit_str = "".join([f"<|speech_{int(token)}|>" for token in token_list]) + output_path = output_root / f"{audio_path.stem}_reply.wav" + + try: + vq_model_audio.decode(speech_unit_str, condition=condition, output_wav_file=str(output_path)) + except Exception as exc: + logger.error(f"Decoding failed for '{audio_path}': {exc}") + continue + + table.add_data( + audio_path.name, + wandb.Audio(str(audio_path), caption="source"), + wandb.Audio(str(output_path), caption="generated"), + len(token_list), + ) + + row_count = getattr(table, "num_rows", None) + if row_count is None: + table_data = getattr(table, "data", None) + row_count = len(table_data) if table_data is not None else 0 + + if row_count > 0: + accelerator.log({"eval/s2s_samples": table}, step=global_step) + else: + logger.warning("S2S evaluation produced no loggable samples.") + + +################################################################################################ +# +++++++++++++++++++++++++++++++++++++ TEXT EVALUATION LOGIC ++++++++++++++++++++++++++++++++++++ +################################################################################################ +@torch.no_grad() +def evaluate_text(model, uni_prompting, config, accelerator, global_step): + if not accelerator.is_main_process: + return + + logger.info("***** Running Text Evaluation *****") + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.eval() + + dataset_cfg = getattr(config.dataset, "params", {}) + prompts_file = getattr(dataset_cfg, "text_eval_prompts_file", "MMaDA/validation_prompts/math.txt") + + prompts_path = Path(prompts_file) + if not prompts_path.is_absolute(): + prompts_path = Path.cwd() / prompts_path + if not prompts_path.exists(): + repo_root = Path(__file__).resolve().parents[2] + alt_path = repo_root / prompts_file + if alt_path.exists(): + prompts_path = alt_path + + if not prompts_path.exists(): + logger.warning(f"Text evaluation prompts file '{prompts_file}' not found. Skipping text evaluation.") + return + + try: + with open(prompts_path, "r", encoding="utf-8") as handle: + raw_prompts = [line.strip() for line in handle if line.strip()] + except OSError as exc: + logger.warning(f"Failed to read text evaluation prompts from '{prompts_path}': {exc}. Skipping text evaluation.") + return + + if not raw_prompts: + logger.warning("Text evaluation prompt list is empty. Skipping text evaluation.") + return + + max_samples = getattr(config.experiment, "eval_num_text_samples", 4) + if not isinstance(max_samples, int) or max_samples <= 0: + max_samples = 4 + questions = raw_prompts[:max_samples] + + chat_prompts = [ + f"<|start_header_id|>user<|end_header_id|>\n{question}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n" + for question in questions + ] + + tokenizer = uni_prompting.text_tokenizer + tokenizer.padding_side = "left" + if tokenizer.pad_token_id is None: + tokenizer.pad_token_id = tokenizer.eos_token_id + + answers: list[str] = [] + + for chat_prompt in chat_prompts: + tokens = tokenizer( + chat_prompt, + return_tensors="pt", + padding=True, + truncation=True, + ) + + input_ids = tokens["input_ids"].to(accelerator.device) + out = generate(unwrapped_model, input_ids, steps=128, gen_length=128, block_length=128, temperature=1, cfg_scale=0., remasking='low_confidence') + answer = tokenizer.batch_decode(out[:, input_ids.shape[1]:], skip_special_tokens=True) + + answers.append(answer) + + table = wandb.Table(columns=["Index", "Question", "Answer"]) + for idx, (question, answer) in enumerate(zip(questions, answers)): + table.add_data(idx, question, answer) + + accelerator.log({"eval/text_samples": table}, step=global_step) + +################################################################################################ +# +++++++++++++++++++++++++++++++++++++ S2T EVALUATION LOGIC +++++++++++++++++++++++++++++++++++++ +################################################################################################ +@torch.no_grad() +def evaluate_s2t(model, vq_model_audio, uni_prompting, config, accelerator, global_step): + if not accelerator.is_main_process: + return + logger.info("***** Running S2T Evaluation (WER on Librispeech test-clean) *****") + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.eval() + + # 1. Load Dataset + try: + s2t_eval_dataset_raw = load_dataset("librispeech_asr", "clean", split="test", streaming=False).select(range(32)) + s2t_eval_dataset = S2TEvalDataset(s2t_eval_dataset_raw, root_path = "/home/work/AIDAS/data/audio/LibriSpeech/test-clean") + except Exception as e: + logger.error(f"Failed to load S2T evaluation dataset: {e}") + return + + collate_with_args = partial( + s2t_eval_collate_fn, + vq_model_audio=vq_model_audio, + tokenizer=uni_prompting.text_tokenizer, + uni_prompting=uni_prompting, + config=config + ) + + s2t_eval_dataloader = DataLoader(s2t_eval_dataset, batch_size=config.training.batch_size_s2t, shuffle=False, collate_fn=collate_with_args) + + local_results = [] + + for batch in tqdm(s2t_eval_dataloader, desc="S2T Evaluation"): + input_ids = batch["input_ids"] + gt_texts = batch["gt_texts"] + sample_ids = batch["sample_ids"] + + output_ids = unwrapped_model.mmu_generate(input_ids, max_new_tokens=256, steps=256, block_length=128, remasking='low_confidence') + + decoded_texts = uni_prompting.text_tokenizer.batch_decode(output_ids[:, input_ids.shape[1]:], skip_special_tokens=True) + + eos_token = uni_prompting.text_tokenizer.eos_token + eos_marker = eos_token if eos_token is not None else "" + for i in range(len(decoded_texts)): + full_text = decoded_texts[i] + eos_idx = full_text.find(eos_marker) + cleaned_text = full_text[:eos_idx] if eos_idx != -1 else full_text + cleaned_text = cleaned_text.replace(eos_marker, "").strip() + local_results.append({ + "sample_id": sample_ids[i], + "gt_text": gt_texts[i], + "decoded_text": cleaned_text, + }) + + if not local_results: + logger.warning("S2T evaluation produced no results.") + return + + gt_list = [res["gt_text"] for res in local_results] + pred_list = [res["decoded_text"] for res in local_results] + + wer, errors, words = calculate_wer(pred_list, gt_list) + logger.info(f"S2T Final WER (Librispeech test-clean): {wer:.4f} | Word Errors: {errors} | Total Words: {words}") + + accelerator.log({ + "eval/s2t_wer": wer, + "eval/s2t_word_errors": errors, + "eval/s2t_total_words": words + }, step=global_step) + + samples_table = wandb.Table(columns=["ID", "Ground Truth", "Prediction"]) + for idx, res in enumerate(local_results): + sample_id = res.get("sample_id", idx) + samples_table.add_data(sample_id, res["gt_text"], res["decoded_text"]) + + accelerator.log({"eval/s2t_samples": samples_table}, step=global_step) + +################################################################################################ +# +++++++++++++++++++++++++++++++++++++ T2S EVALUATION LOGIC +++++++++++++++++++++++++++++++++++++ +################################################################################################ +@torch.no_grad() +def evaluate_t2s(model, vq_model_audio, uni_prompting, config, accelerator, global_step): + if not accelerator.is_main_process: + return + logger.info("***** Running T2S Evaluation (WER via Whisper on Librispeech) *****") + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.eval() + + # 1. Load Dataset & Whisper Model + try: + t2s_eval_dataset_raw = load_dataset("librispeech_asr", "clean", split="test").select(range(32)) + whisper_pipe = pipeline("automatic-speech-recognition", model="openai/whisper-large-v3", device=accelerator.device) + os.makedirs(f"{config.experiment.output_dir}/eval_audio", exist_ok=True) + except Exception as e: + logger.error(f"Failed to load T2S dataset or Whisper model: {e}") + return + + output_dir_per_step = os.path.join("/home/work/AIDAS", config.experiment.output_dir, "eval_audio", f"step_{global_step}") + os.makedirs(output_dir_per_step, exist_ok=True) + + t2s_eval_dataset = T2SEvalDataset(t2s_eval_dataset_raw) + t2s_dataloader = DataLoader(t2s_eval_dataset, batch_size=config.training.batch_size_t2s) + + local_results = [] + mask_token_id = unwrapped_model.config.mask_token_id + mask_schedule = get_mask_schedule(config.training.get("mask_schedule", "cosine")) + + # 2. Evaluation Loop + for batch in tqdm(t2s_dataloader, desc="T2S Evaluation"): + gt_texts = batch["gt_text"] + sample_ids = batch["sample_id"] + + # Chat-style instruction formatting for T2S: user prompt + text + prompts = [ + f"<|start_header_id|>user<|end_header_id|>\n{random.choice(T2S_INSTRUCTION)}\n{text}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n" + for text in gt_texts + ] + batch_size = len(prompts) + + # We need a reasonable length for generated audio tokens + speech_token_length = 384 - 1 # -1 for soa token + audio_tokens = torch.ones((batch_size, speech_token_length), dtype=torch.long, device=accelerator.device) * mask_token_id + input_ids, attention_mask = uni_prompting((prompts, audio_tokens), 't2s_gen') + + if config.training.guidance_scale > 0: + uncond_input_ids, uncond_attention_mask = uni_prompting(([''] * batch_size, audio_tokens), 't2s_gen') + else: + uncond_input_ids, uncond_attention_mask = None, None + + output_ids = unwrapped_model.t2s_generate( + input_ids=input_ids, + uncond_input_ids=uncond_input_ids, + attention_mask=attention_mask, + uncond_attention_mask=uncond_attention_mask, + guidance_scale=5.0, + temperature=1.0, + timesteps=50, + noise_schedule=mask_schedule, + noise_type="mask", + seq_len=383, + uni_prompting=uni_prompting, + config=config, + ) + + # Decode and run Whisper + for i in range(batch_size): + gt = gt_texts[i].rsplit("\n", 1)[-1].strip() + + gen_speech_tokens = output_ids[i] + + # Remove padding/eos if necessary, clamp to valid range + # gen_speech_tokens = torch.clamp(gen_speech_tokens, min=0, max= 4096 - 1) + id_list = gen_speech_tokens.cpu().tolist() + + if not id_list: + logger.warning(f"Generated token list is empty for sample {sample_ids[i]}. Skipping.") + continue + + speech_unit_str = " ".join(map(str, id_list)) + speech_unit_for_decode = "".join([f"<|speech_{unit}|>" for unit in speech_unit_str.split(" ")]) + + filename = f"process_{accelerator.process_index}_{sample_ids[i]}.wav" + output_wav_path = os.path.join(output_dir_per_step, filename) + condition = 'gender-female_emotion-neutral_speed-normal_pitch-normal' + + audio_array = vq_model_audio.decode(speech_unit_for_decode, condition=condition, output_wav_file=output_wav_path) + + whisper_result = whisper_pipe(output_wav_path, generate_kwargs={"language": "english"}) + whisper_text = whisper_result.get("text", "") + + local_results.append({ + "sample_id": sample_ids[i], "gt_text": gt, "whisper_text": whisper_text, "audio_path": output_wav_path + }) + + if not local_results: + logger.warning("Skipping T2S evaluation logging because no samples were generated.") + return + + gt_list = [res["gt_text"] for res in local_results] + pred_list = [res["whisper_text"] for res in local_results] + + wer, errors, words = calculate_wer(pred_list, gt_list) + logger.info(f"T2S Final WER (via Whisper): {wer:.4f} | Word Errors: {errors} | Total Words: {words}") + + accelerator.log({ + "eval/t2s_wer": wer, + "eval/t2s_word_errors": errors, + "eval/t2s_total_words": words + }, step=global_step) + + results_table = wandb.Table(columns=["ID", "Ground Truth", "Whisper Transcription", "Generated Audio"]) + for res in local_results[:8]: + audio = wandb.Audio(res["audio_path"], caption=res["whisper_text"]) + results_table.add_data(res["sample_id"], res["gt_text"], res["whisper_text"], audio) + + accelerator.log({"eval/t2s_samples": results_table}, step=global_step) + +@torch.no_grad() +def evaluate_t2s_mmu_like(model, vq_model_audio, uni_prompting, config, accelerator, global_step): + """Text-to-speech evaluation using the MMU-style block refinement decoder.""" + + if not accelerator.is_main_process: + return + + logger.info("***** Running T2S Evaluation (MMU-style decoder) *****") + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.eval() + + try: + t2s_eval_dataset_raw = load_dataset("librispeech_asr", "clean", split="test").select(range(32)) + whisper_pipe = pipeline("automatic-speech-recognition", model="openai/whisper-large-v3", device=accelerator.device) + os.makedirs(f"{config.experiment.output_dir}/eval_audio", exist_ok=True) + except Exception as exc: + logger.error(f"Failed to load T2S dataset or Whisper model for MMU-style eval: {exc}") + return + + output_dir_per_step = os.path.join("/home/work/AIDAS", config.experiment.output_dir, "eval_audio", f"step_{global_step}_mmu") + os.makedirs(output_dir_per_step, exist_ok=True) + + t2s_eval_dataset = T2SEvalDataset(t2s_eval_dataset_raw) + t2s_dataloader = DataLoader(t2s_eval_dataset, batch_size=config.training.batch_size_t2s) + + local_results = [] + mask_token_id = unwrapped_model.config.mask_token_id + + codebook_size = config.model.omada.codebook_size + speech_vocab_size = 4096 + + for batch in tqdm(t2s_dataloader, desc="T2S MMU Eval"): + gt_texts = batch["gt_text"] + sample_ids = batch["sample_id"] + + prompts = [ + f"<|start_header_id|>user<|end_header_id|>\n{random.choice(T2S_INSTRUCTION)}\n{text}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n" + for text in gt_texts + ] + + batch_size = len(prompts) + speech_token_length = 384 - 1 + audio_tokens = torch.ones((batch_size, speech_token_length), dtype=torch.long, device=accelerator.device) * mask_token_id + input_ids, attention_mask = uni_prompting((prompts, audio_tokens), 't2s_gen') + + output_ids = unwrapped_model.t2s_generate_mmu_like( + input_ids=input_ids, + max_new_tokens=speech_token_length, + steps=384 - 1, + block_length=384 - 1, + temperature=1.0, + cfg_scale=3.5, + mask_token_id=mask_token_id, + attention_mask=attention_mask, + uni_prompting=uni_prompting, + codebook_size=codebook_size, + audio_codebook_size=speech_vocab_size, + ) + + for i in range(batch_size): + gt = gt_texts[i].rsplit("\n", 1)[-1].strip() + + gen_speech_tokens = output_ids[i] + if isinstance(gen_speech_tokens, torch.Tensor): + gen_speech_tokens = gen_speech_tokens.detach().cpu() + + token_list = gen_speech_tokens.tolist() + if not token_list: + logger.warning(f"Generated token list is empty for sample {sample_ids[i]} (MMU eval). Skipping.") + continue + + speech_unit_str = " ".join(map(str, token_list)) + speech_unit_for_decode = "".join([f"<|speech_{unit}|>" for unit in speech_unit_str.split(" ")]) + + filename = f"process_{accelerator.process_index}_{sample_ids[i]}_mmu.wav" + output_wav_path = os.path.join(output_dir_per_step, filename) + condition = 'gender-female_emotion-neutral_speed-normal_pitch-normal' + + try: + vq_model_audio.decode(speech_unit_for_decode, condition=condition, output_wav_file=output_wav_path) + except Exception as exc: + logger.error(f"Decoding failed for sample {sample_ids[i]} (MMU eval): {exc}") + continue + + whisper_result = whisper_pipe(output_wav_path, generate_kwargs={"language": "english"}) + whisper_text = whisper_result.get("text", "") + + local_results.append({ + "sample_id": sample_ids[i], + "gt_text": gt, + "whisper_text": whisper_text, + "audio_path": output_wav_path, + }) + + if not local_results: + logger.warning("Skipping T2S MMU-style evaluation because no samples were generated.") + return + + gt_list = [res["gt_text"] for res in local_results] + pred_list = [res["whisper_text"] for res in local_results] + + wer, errors, words = calculate_wer(pred_list, gt_list) + logger.info(f"T2S (MMU-style) Final WER: {wer:.4f} | Word Errors: {errors} | Total Words: {words}") + + accelerator.log({ + "eval/t2s_mmu_like_wer": wer, + "eval/t2s_mmu_like_word_errors": errors, + "eval/t2s_mmu_like_total_words": words, + }, step=global_step) + + results_table = wandb.Table(columns=["ID", "Ground Truth", "Whisper Transcription", "Generated Audio"]) + for res in local_results[:8]: + audio = wandb.Audio(res["audio_path"], caption=res["whisper_text"]) + results_table.add_data(res["sample_id"], res["gt_text"], res["whisper_text"], audio) + + accelerator.log({"eval/t2s_mmu_like_samples": results_table}, step=global_step) + +################################################################################################ +# +++++++++++++++++++++++++++++++++++++ V2T EVALUATION LOGIC +++++++++++++++++++++++++++++++++++++ +################################################################################################ +@torch.no_grad() +def evaluate_v2t(model, vq_model_image, uni_prompting, config, accelerator, global_step): + # This is a qualitative evaluation, so it only runs on the main process. + if not accelerator.is_main_process: + return + + logger.info("***** Running V2T Qualitative Evaluation *****") + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.eval() + + video_root = "/home/work/AIDAS/video/demo" + if not video_root or not os.path.exists(video_root): + logger.warning(f"V2T eval root '{video_root}' not found. Skipping V2T evaluation.") + return + + file_list = [f for f in os.listdir(video_root) if f.lower().endswith('.mp4')] + if not file_list: + logger.warning(f"No .mp4 files found in '{video_root}'. Skipping V2T evaluation.") + return + + question = "Please provide a detailed description of the video." + results_table = wandb.Table(columns=["Video ID", "Question", "Generated Caption"]) + + for file_name in tqdm(file_list[:], desc="V2T Evaluation", disable=not accelerator.is_main_process): + video_path = os.path.join(video_root, file_name) + + # 1. Load and process video + cap = cv2.VideoCapture(video_path) + total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + indices = np.linspace(0, total_frames - 1, 8, dtype=int) + frames = [] + for i in range(total_frames): + ret, frame = cap.read() + if i in indices: + if not ret: continue + frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + pil_img = Image.fromarray(frame) + frames.append(image_transform(pil_img, resolution=config.dataset.preprocessing.resolution)) + cap.release() + + if len(frames) < 8: continue + + video_tensor = torch.stack(frames).to(accelerator.device) + video_tokens = vq_model_image.get_code(video_tensor) + len(uni_prompting.text_tokenizer) + video_tokens = video_tokens.view(1, -1) # Flatten tokens + + sptids = uni_prompting.sptids_dict + device = unwrapped_model.device + + prompt_text = f'<|start_header_id|>user<|end_header_id|>\n{question}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n' + prompt_tensor = uni_prompting.text_tokenizer(prompt_text, return_tensors="pt").input_ids.to(device) + + input_ids = torch.cat([ + sptids['<|v2t|>'].to(device).unsqueeze(0), + sptids['<|soi|>'].to(device).unsqueeze(0), + video_tokens, + sptids['<|eoi|>'].to(device).unsqueeze(0), + sptids['<|sot|>'].to(device).unsqueeze(0), + prompt_tensor + ], dim=1).long() + + output_ids = unwrapped_model.mmu_generate(input_ids, max_new_tokens=256, steps=256, block_length=128) + text = uni_prompting.text_tokenizer.batch_decode(output_ids[:, input_ids.shape[1]:], skip_special_tokens=True)[0] + print(text) + # 3. Log result + results_table.add_data(file_name, question, text) + + # except Exception as e: + # logger.error(f"Error processing video {file_name}: {e}") + + accelerator.log({"eval/v2t_qualitative_samples": results_table}, step=global_step) + + +################################################################################################ +# +++++++++++++++++++++++++++++++++++++ V2S EVALUATION LOGIC +++++++++++++++++++++++++++++++++++++ +################################################################################################ +@torch.no_grad() +def evaluate_v2s(model, vq_model_image, vq_model_audio, uni_prompting, config, accelerator, global_step): + if not accelerator.is_main_process: + return + + logger.info("***** Running V2S Qualitative Evaluation *****") + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.eval() + + try: + whisper_pipe = pipeline("automatic-speech-recognition", model="openai/whisper-large-v3", device=accelerator.device) + except Exception as exc: + logger.error(f"Failed to load Whisper model for V2S eval: {exc}") + return + + video_root = "/home/work/AIDAS/video/demo" + if not video_root or not os.path.exists(video_root): + logger.warning(f"V2S eval root '{video_root}' not found. Skipping V2S evaluation.") + return + + file_list = [f for f in os.listdir(video_root) if f.lower().endswith('.mp4')] + if not file_list: + logger.warning(f"No .mp4 files found in '{video_root}'. Skipping V2S evaluation.") + return + + question = "Please provide a detailed description of the video." + results_table = wandb.Table(columns=["Video ID", "Question", "Whisper Transcript", "Generated Audio"]) + + device = unwrapped_model.device + mask_token_id = unwrapped_model.config.mask_token_id + eoa_token_id = int(uni_prompting.sptids_dict['<|eoa|>'][0].item()) + audio_codebook_size = 4096 + max_audio_tokens = int(getattr(uni_prompting, "max_audio_len_short", config.dataset.preprocessing.max_aud_length_short)) + max_new_tokens = max(1, max_audio_tokens - 1) + block_length = 128 if max_new_tokens % 128 == 0 else max_new_tokens + + output_dir = os.path.join("/home/work/AIDAS", config.experiment.output_dir, "eval_audio_v2s", f"step_{global_step}") + os.makedirs(output_dir, exist_ok=True) + + for file_name in tqdm(file_list[:], desc="V2S Evaluation", disable=not accelerator.is_main_process): + video_path = os.path.join(video_root, file_name) + + cap = cv2.VideoCapture(video_path) + total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + indices = np.linspace(0, total_frames - 1, 8, dtype=int) + frames = [] + for idx in range(total_frames): + ret, frame = cap.read() + if idx in indices: + if not ret: + continue + frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + pil_img = Image.fromarray(frame) + frames.append(image_transform(pil_img, resolution=config.dataset.preprocessing.resolution)) + cap.release() + + if len(frames) < 8: + logger.warning(f"Skipping {file_name}: insufficient frames.") + continue + + video_tensor = torch.stack(frames).to(device) + try: + video_tokens = vq_model_image.get_code(video_tensor) + len(uni_prompting.text_tokenizer) + except Exception as exc: + logger.error(f"Failed to encode video {file_name}: {exc}") + continue + video_tokens = video_tokens.view(1, -1) + + prompt_text = f'<|start_header_id|>user<|end_header_id|>\n{question}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n' + prompt_tensor = uni_prompting.text_tokenizer(prompt_text, return_tensors="pt").input_ids.to(device) + + audio_placeholder = torch.full( + (1, max_new_tokens), + mask_token_id, + dtype=torch.long, + device=device, + ) + try: + seq_ids, attn_mask = uni_prompting( + (video_tokens, [prompt_text], [audio_placeholder]), + 'v2s_gen' + ) + except Exception as exc: + logger.error(f"Prompt construction failed for {file_name}: {exc}") + continue + + input_ids = seq_ids.to(device) + attention_mask = attn_mask.to(device) + + try: + generated_list = unwrapped_model.t2s_generate_mmu_like( + input_ids=input_ids, + max_new_tokens=256, + steps=128, + block_length=128, + temperature=1.0, + cfg_scale=3.0, + mask_token_id=mask_token_id, + attention_mask=attention_mask, + uni_prompting=uni_prompting, + codebook_size=config.model.omada.codebook_size, + audio_codebook_size=audio_codebook_size, + ) + except Exception as exc: + logger.error(f"Generation failed for {file_name}: {exc}") + continue + + generated_tokens = generated_list[0] + if isinstance(generated_tokens, torch.Tensor): + generated_tokens = generated_tokens.detach().cpu() + + collected = [] + for token in generated_tokens.tolist(): + if token == eoa_token_id or token >= audio_codebook_size: + break + if token >= 0: + collected.append(token) + + if not collected: + logger.warning(f"No valid audio tokens generated for {file_name}.") + continue + + speech_unit_for_decode = "".join(f"<|speech_{tok}|>" for tok in collected) + output_wav_path = os.path.join(output_dir, f"{Path(file_name).stem}_v2s.wav") + try: + vq_model_audio.decode( + speech_unit_for_decode, + condition='gender-female_emotion-neutral_speed-normal_pitch-normal', + output_wav_file=output_wav_path + ) + except Exception as exc: + logger.error(f"Decoding failed for {file_name}: {exc}") + continue + + whisper_result = whisper_pipe(output_wav_path, generate_kwargs={"language": "english"}) + whisper_text = whisper_result.get("text", "") + + results_table.add_data( + file_name, + question, + whisper_text, + wandb.Audio(output_wav_path, caption=whisper_text) + ) + + if len(results_table.data) == 0: + logger.warning("V2S evaluation produced no samples to log.") + return + + accelerator.log({"eval/v2s_qualitative_samples": results_table}, step=global_step) + + +################################################################################################ +# +++++++++++++++++++++++++++++++++++++ MAIN EVALUATION ORCHESTRATOR +++++++++++++++++++++++++++++ +################################################################################################ + +def run_evaluation(model, vq_model_image, vq_model_audio, uni_prompting, config, accelerator, global_step): + """ + Orchestrates the S2T, T2S, and V2T evaluations. + """ + if accelerator.is_main_process: + logger.info(f"--- Starting evaluation at step {global_step} ---") + model.eval() + + + if accelerator.is_main_process: + evaluate_v2s(model, vq_model_image, vq_model_audio, uni_prompting, config, accelerator, global_step) + evaluate_text(model, uni_prompting, config, accelerator, global_step) + evaluate_t2i(model, vq_model_image, uni_prompting, config, accelerator, global_step) + evaluate_i2i(model, vq_model_image, uni_prompting, config, accelerator, global_step) + evaluate_s2s(model, vq_model_audio, uni_prompting, config, accelerator, global_step) + evaluate_s2t(model, vq_model_audio, uni_prompting, config, accelerator, global_step) + evaluate_t2s_mmu_like(model, vq_model_audio, uni_prompting, config, accelerator, global_step) + # evaluate_v2t(model, vq_model_image, uni_prompting, config, accelerator, global_step) + + accelerator.wait_for_everyone() + if accelerator.is_main_process: + logger.info(f"--- Finished evaluation at step {global_step}. Returning to training. ---") + model.train() + + +def main(): + _configure_multiprocessing() + ######################### + # SETUP Accelerator # + ######################### + config = get_config() + + # Enable TF32 on Ampere GPUs + if config.training.enable_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.benchmark = True + torch.backends.cudnn.deterministic = False + + config.experiment.logging_dir = str(Path(config.experiment.output_dir) / "logs") + accelerator = Accelerator( + gradient_accumulation_steps=config.training.gradient_accumulation_steps, + mixed_precision=config.training.mixed_precision, + log_with="wandb", + project_dir=config.experiment.logging_dir, + split_batches=True, + ) + + total_batch_size_per_gpu = ( + config.training.batch_size_t2i + + config.training.batch_size_lm + + config.training.batch_size_mmu + + config.training.batch_size_v2t + + config.training.batch_size_s2t + + config.training.batch_size_t2s + + config.training.batch_size_s2s + ) - 1 # -1 since t2s/ s2t choice + + total_batch_size = ( + total_batch_size_per_gpu + * accelerator.num_processes + * config.training.gradient_accumulation_steps + ) + + if accelerator.distributed_type == DistributedType.DEEPSPEED: + accelerator.state.deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = ( + total_batch_size_per_gpu + ) + + ##################################### + # SETUP LOGGING, SEED and CONFIG # + ##################################### + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + set_verbosity_info() + else: + set_verbosity_error() + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + resume_wandb_run = config.wandb.resume + run_id = config.wandb.get("run_id", None) + if run_id is None: + resume_wandb_run = False + run_id = wandb.util.generate_id() + config.wandb.run_id = run_id + + wandb_init_kwargs = dict( + name=config.experiment.name, + id=run_id, + resume=resume_wandb_run, + entity=config.wandb.get("entity", None), + config_exclude_keys=[], + dir = config.experiment.logging_dir, + ) + wandb_config = {k: v for k, v in flatten_omega_conf(config, resolve=True)} + wandb_config.pop("experiment.resume_from_checkpoint") + + accelerator.init_trackers( + config.experiment.project, + config=wandb_config, + init_kwargs={"wandb": wandb_init_kwargs}, + ) + + if accelerator.is_main_process: + os.makedirs(config.experiment.output_dir, exist_ok=True) + config_path = Path(config.experiment.output_dir) / "config.yaml" + logging.info(f"Saving config to {config_path}") + OmegaConf.save(config, config_path) + + # If passed along, set the training seed now. + if config.training.seed is not None: + set_seed(config.training.seed) + + ######################### + # MODELS and OPTIMIZER # + ######################### + logger.info("Loading models and optimizer") + + tokenizer = AutoTokenizer.from_pretrained(config.model.omada.tokenizer_path, padding_side="left") + + uni_prompting = UniversalPrompting(tokenizer, max_text_len=config.dataset.preprocessing.max_seq_length, max_audio_len=config.dataset.preprocessing.max_aud_length, max_audio_len_short=config.dataset.preprocessing.max_aud_length_short, + special_tokens=( + "<|soi|>", "<|eoi|>", "<|sov|>", "<|eov|>", "<|t2i|>", + "<|mmu|>", "<|t2v|>", "<|v2v|>", "<|lvg|>", + # Omada Special Tokens + "<|i2i|>", "<|v2t|>", "<|v2s|>", "<|s2t|>", "<|t2s|>", "<|s2s|>", "<|soa|>", "<|eoa|>", + ), + ignore_id=-100, cond_dropout_prob=config.training.cond_dropout_prob, use_reserved_token=True) + + print('special tokens : \n', uni_prompting.sptids_dict) + + speech_vocab_start = len(uni_prompting.text_tokenizer) + int(config.model.omada.codebook_size) + audio_codebook_size = max(int(config.model.omada.new_vocab_size) - speech_vocab_start, 0) + + # speech_vocab_start = int(config.model.omada.llm_vocab_size) + int(config.model.omada.codebook_size) # 126464 + 8192 = 134656 + # audio_codebook_size = max(int(config.model.omada.new_vocab_size) - speech_vocab_start, 0) # 4096 + + logger.info(f"SPEECHVOCABSTART: {speech_vocab_start}") + logger.info(f"int(config.model.omada.new_vocab_size): {int(config.model.omada.new_vocab_size)}") + logger.info(f"AUDIOCODEBOOKSIZE: {audio_codebook_size}") + + t2s_special_token_ids = { + "eoa": int(uni_prompting.sptids_dict['<|eoa|>'][0].item()), + "eos": int(uni_prompting.text_tokenizer.eos_token_id), + } + + # VQ model for processing image into discrete tokens + vq_model_image = get_vq_model_class(config.model.vq_model_image.type) + if config.model.vq_model_image.get("pretrained_model_path", None): + vq_model_image = vq_model_image().to(accelerator.device) + state_dict = torch.load(config.model.vq_model_image.pretrained_model_path)['model'] + vq_model_image.load_state_dict(state_dict) + else: + vq_model_image = vq_model_image.from_pretrained(config.model.vq_model_image.vq_model_name).to(accelerator.device) + + vq_model_audio = get_vq_model_class(config.model.vq_model_audio.type) + vq_model_audio = vq_model_audio.from_pretrained(config.model.vq_model_audio.vq_model_name).to(accelerator.device) + + vq_model_image.eval() + vq_model_image.requires_grad_(False) + + vq_model_audio.eval() + vq_model_audio.requires_grad_(False) + + # Speech-token caching configuration + speech_cache_cfg = getattr(config.dataset, "speech_token_cache", {}) + if not isinstance(speech_cache_cfg, dict): + speech_cache_cfg = OmegaConf.to_container(speech_cache_cfg, resolve=True) + speech_cache_cfg = speech_cache_cfg or {} + + speech_cache_enabled = bool(speech_cache_cfg.get("enable", False)) + speech_cache_dir: Optional[Path] + if speech_cache_enabled: + cache_root = speech_cache_cfg.get("root", "cache/speech_tokens") + speech_cache_dir = Path(cache_root) + try: + speech_cache_dir.mkdir(parents=True, exist_ok=True) + except OSError: + speech_cache_dir = None + speech_cache_enabled = False + logger.warning("Failed to create speech cache directory at %s; disabling cache.", cache_root) + else: + speech_cache_dir = None + + speech_cache_max_items = int(speech_cache_cfg.get("max_items_in_memory", 4096)) + audio_token_cache_mem: Dict[str, torch.Tensor] = {} + + def _get_audio_cache_path(audio_path: Union[str, Path]) -> Optional[Path]: + if not isinstance(audio_path, (str, os.PathLike)): + return None + if not speech_cache_enabled or speech_cache_dir is None: + return None + key = os.path.abspath(str(audio_path)) + digest = hashlib.sha1(key.encode("utf-8")).hexdigest() + subdir = speech_cache_dir / digest[:2] / digest[2:4] + return subdir / f"{digest}.pt" + + def _load_cached_audio_tokens(audio_path: Union[str, Path]) -> Optional[torch.Tensor]: + if not isinstance(audio_path, (str, os.PathLike)): + return None + cache_key = os.path.abspath(str(audio_path)) + cached = audio_token_cache_mem.get(cache_key) + if cached is not None: + return cached.clone() + + cache_path = _get_audio_cache_path(audio_path) + if cache_path is None or not cache_path.exists(): + return None + try: + tokens = torch.load(cache_path, map_location="cpu") + if isinstance(tokens, torch.Tensor): + if len(audio_token_cache_mem) < speech_cache_max_items: + audio_token_cache_mem[cache_key] = tokens + return tokens.clone() + except Exception as exc: + logger.warning("Failed to load cached speech tokens from %s (%s); ignoring cache.", cache_path, exc) + return None + + def _store_cached_audio_tokens(audio_path: Union[str, Path], tokens: torch.Tensor) -> None: + if not isinstance(audio_path, (str, os.PathLike)): + return + cache_path = _get_audio_cache_path(audio_path) + if cache_path is None: + return + try: + cache_path.parent.mkdir(parents=True, exist_ok=True) + tmp_path = cache_path.with_suffix(cache_path.suffix + ".tmp") + torch.save(tokens.cpu(), tmp_path) + os.replace(tmp_path, cache_path) + except Exception as exc: + logger.warning("Failed to write speech token cache to %s (%s).", cache_path, exc) + return + cache_key = os.path.abspath(str(audio_path)) + if len(audio_token_cache_mem) < speech_cache_max_items: + audio_token_cache_mem[cache_key] = tokens.cpu() + + model = OMadaModelLM.from_pretrained(config.model.omada.pretrained_model_path, torch_dtype=torch.bfloat16, config='/home/work/AIDAS/ckpts/omada/omada-training-stage1_7th/checkpoint-315000/unwrapped_model/config.json').to(accelerator.device) + mask_id = model.config.mask_token_id + + ################################## + # Optimizer and LR scheduler # + ################################# + optimizer_config = config.optimizer.params + + # no decay on bias and layernorm and embedding + no_decay = ["bias", "layer_norm.weight", "mlm_ln.weight", "embeddings.weight"] + optimizer_grouped_parameters = [ + { + "params": [p for n, p in model.named_parameters() if + p.requires_grad and not any(nd in n for nd in no_decay)], + "weight_decay": optimizer_config.weight_decay, + }, + { + "params": [p for n, p in model.named_parameters() if + p.requires_grad and any(nd in n for nd in no_decay)], + "weight_decay": 0.0, + }, + ] + + optimizer_type = config.optimizer.name + if optimizer_type == "adamw": + optimizer = AdamW( + optimizer_grouped_parameters, + lr=optimizer_config.learning_rate, + betas=(optimizer_config.beta1, optimizer_config.beta2), + weight_decay=optimizer_config.weight_decay, + eps=optimizer_config.epsilon, + ) + else: + raise ValueError(f"Optimizer {optimizer_type} not supported") + + # Create mask scheduler + if config.get("mask_schedule", None) is not None: + schedule = config.mask_schedule.schedule + args = config.mask_schedule.get("params", {}) + mask_schedule = get_mask_schedule(schedule, **args) + else: + mask_schedule = get_mask_schedule(config.training.get("mask_schedule", "cosine")) + + ################################## + # DATALOADER # + ################################# + logger.info("Creating dataloaders and lr_scheduler") + + def build_distributed_sampler(dataset, *, shuffle=True, drop_last=True): + """Create a DistributedSampler only when running with multiple processes.""" + if dataset is None or accelerator.num_processes <= 1: + return None + return DistributedSampler( + dataset, + num_replicas=accelerator.num_processes, + rank=accelerator.process_index, + shuffle=shuffle, + drop_last=drop_last, + ) + + batch_size_t2i_cfg = config.training.batch_size_t2i + batch_size_lm_cfg = config.training.batch_size_lm + batch_size_mmu_cfg = config.training.batch_size_mmu + batch_size_t2s_cfg = config.training.batch_size_t2s + batch_size_s2t_cfg = config.training.batch_size_s2t + batch_size_v2t_cfg = config.training.batch_size_v2t + batch_size_s2s_cfg = config.training.batch_size_s2s + batch_size_v2s_cfg = batch_size_v2t_cfg + + total_batch_size = ( + total_batch_size_per_gpu + * accelerator.num_processes + * config.training.gradient_accumulation_steps + ) + preproc_config = config.dataset.preprocessing + dataset_config = config.dataset.params + + pin_memory = bool(getattr(dataset_config, "pin_memory", False)) + persistent_workers = bool(getattr(dataset_config, "persistent_workers", False)) + dataloader_timeout = int(getattr(dataset_config, "dataloader_timeout", 120)) + + if persistent_workers and dataloader_timeout > 0: + logger.warning( + "persistent_workers=True requires dataloader_timeout=0; overriding timeout=%s", + dataloader_timeout, + ) + dataloader_timeout = 0 + + if ( + not persistent_workers + and int(getattr(dataset_config, "num_workers", 0)) > 0 + and str(config.dataset.combined_loader_mode) == "max_size_cycle" + ): + logger.warning( + "Using combined_loader_mode='max_size_cycle' with num_workers>0 and " + "persistent_workers=False can exhaust OS semaphores when loaders cycle. " + "Set dataset.params.persistent_workers=True to keep worker processes alive." + ) + + # Text-to-image / Image-to-image datasets + logger.info("Loading Text-to-image / Image-to-image datasets") + dataset_t2i = None + dataset_i2i = None + train_dataloader_t2i = None + train_dataloader_i2i = None + sampler_t2i: Optional[DistributedSampler] = None # type: ignore[assignment] + sampler_i2i: Optional[DistributedSampler] = None # type: ignore[assignment] + if batch_size_t2i_cfg > 0: + raw_t2i_choice = dataset_config.get("t2i_dataset", "hqedit") + if isinstance(raw_t2i_choice, str): + split_tokens = [token.strip() for token in raw_t2i_choice.replace(",", "+").split("+")] + dataset_choices = [token for token in split_tokens if token] + else: + dataset_choices = [str(token).strip() for token in raw_t2i_choice if str(token).strip()] + + if not dataset_choices: + raise ValueError("t2i_dataset configuration produced no valid dataset names.") + + t2i_datasets: list[Dataset] = [] + i2i_datasets: list[Dataset] = [] + t2i_source_names: list[str] = [] + i2i_source_names: list[str] = [] + for choice in dataset_choices: + choice_lower = choice.lower() + if choice_lower in {"hqedit", "hq-edit", "hq_edit"}: + i2i_datasets.append( + HQEditX2IDataset( + split=dataset_config.get("hqedit_split", "train"), + resolution=dataset_config.resolution, + ) + ) + logger.info("Using HQ-Edit dataset for T2I/i2i branch (%s split)", dataset_config.get("hqedit_split", "train")) + i2i_source_names.append(choice) + elif choice_lower in {"text2image2m", "text-to-image-2m", "text_to_image_2m"}: + t2i_datasets.append( + TextToImage2MDataset( + split=dataset_config.get("t2i_split", "train"), + resolution=dataset_config.resolution, + dataset_name=dataset_config.get("t2i_dataset_name", "jackyhate/text-to-image-2M"), + cache_dir=dataset_config.get("t2i_cache_dir", None), + local_files_only=bool(dataset_config.get("t2i_local_files_only", False)), + ) + ) + logger.info( + "Using text-to-image-2M dataset for T2I branch (split=%s, dataset=%s)", + dataset_config.get("t2i_split", "train"), + dataset_config.get("t2i_dataset_name", "jackyhate/text-to-image-2M"), + ) + t2i_source_names.append(choice) + elif choice_lower in {"openimage", "openimage_i2i", "openimage-edit", "openimage_local"}: + raw_openimage_cfg = getattr(dataset_config, "openimage_i2i", None) + if raw_openimage_cfg is None: + raise ValueError("dataset.params.openimage_i2i must be configured to use the OpenImage dataset.") + if not isinstance(raw_openimage_cfg, dict): + openimage_cfg = OmegaConf.to_container(raw_openimage_cfg, resolve=True) or {} + else: + openimage_cfg = raw_openimage_cfg + + i2i_datasets.append( + OpenImageI2IDataset( + resolution=dataset_config.resolution, + image_root=openimage_cfg.get("image_root"), + sft_jsonl=openimage_cfg.get("sft_jsonl"), + pref_jsonl=openimage_cfg.get("pref_jsonl"), + multi_turn_jsonl=openimage_cfg.get("multi_turn_jsonl"), + prefer_summarized_text=bool(openimage_cfg.get("prefer_summarized_text", True)), + pref_positive_only=bool(openimage_cfg.get("pref_positive_only", True)), + skip_missing=bool(openimage_cfg.get("skip_missing", True)), + max_samples_per_source=openimage_cfg.get("max_samples_per_source"), + max_total_samples=openimage_cfg.get("max_total_samples"), + seed=openimage_cfg.get("seed"), + ) + ) + cfg_paths = [ + openimage_cfg.get("sft_jsonl"), + openimage_cfg.get("pref_jsonl"), + openimage_cfg.get("multi_turn_jsonl"), + ] + cfg_paths = [str(path) for path in cfg_paths if path] + logger.info( + "Using OpenImage local edit dataset for i2i (jsonl=%s)", + ", ".join(cfg_paths) if cfg_paths else "n/a", + ) + i2i_source_names.append(choice) + else: + raise ValueError(f"Unsupported t2i_dataset '{choice}'") + + if t2i_datasets: + dataset_t2i = ( + t2i_datasets[0] + if len(t2i_datasets) == 1 + else CombinedX2IDataset(t2i_datasets) + ) + logger.info( + "T2I dataloading sources: %s", + ", ".join(t2i_source_names) if t2i_source_names else "n/a", + ) + + sampler_t2i = build_distributed_sampler( + dataset_t2i, + shuffle=True, + drop_last=True, + ) + + train_dataloader_t2i = DataLoader( + dataset_t2i, + batch_size=batch_size_t2i_cfg, + sampler=sampler_t2i, + shuffle=sampler_t2i is None, + num_workers=dataset_config.num_workers, + collate_fn=collate_fn_x2i, + drop_last=True, + pin_memory=pin_memory, + timeout=dataloader_timeout, + persistent_workers=persistent_workers, + ) + + if i2i_datasets: + dataset_i2i = ( + i2i_datasets[0] + if len(i2i_datasets) == 1 + else CombinedX2IDataset(i2i_datasets) + ) + logger.info( + "I2I dataloading sources: %s", + ", ".join(i2i_source_names) if i2i_source_names else "n/a", + ) + + sampler_i2i = build_distributed_sampler( + dataset_i2i, + shuffle=True, + drop_last=True, + ) + + train_dataloader_i2i = DataLoader( + dataset_i2i, + batch_size=batch_size_t2i_cfg, + sampler=sampler_i2i, + shuffle=sampler_i2i is None, + num_workers=dataset_config.num_workers, + collate_fn=collate_fn_x2i, + drop_last=True, + pin_memory=pin_memory, + timeout=dataloader_timeout, + persistent_workers=persistent_workers, + ) + + # Language modeling dataset (HF instruction mixture) + logger.info("Loading LM dataset") + dataset_lm = None + train_dataloader_lm = None + if batch_size_lm_cfg > 0: + instruction_cfg = getattr(dataset_config, "hf_instruction_lm", {}) + if not isinstance(instruction_cfg, dict): + instruction_cfg = OmegaConf.to_container(instruction_cfg, resolve=True) + instruction_cfg = instruction_cfg or {} + + seed_lm = instruction_cfg.get("seed") + if seed_lm is None: + seed_lm = getattr(config.training, "seed", 42) or 42 + + dataset_lm = HFInstructionTextDataset( + split=instruction_cfg.get("split", "train"), + max_samples_per_source=instruction_cfg.get("max_samples_per_source"), + max_total_samples=instruction_cfg.get("max_total_samples"), + seed=int(seed_lm), + ) + + sampler_lm = build_distributed_sampler( + dataset_lm, + shuffle=True, + drop_last=True, + ) + + train_dataloader_lm = DataLoader( + dataset_lm, + batch_size=batch_size_lm_cfg, + sampler=sampler_lm, + shuffle=sampler_lm is None, + collate_fn=dataset_lm.collate_fn, + num_workers=dataset_config.num_workers, + drop_last=True, + pin_memory=pin_memory, + timeout=dataloader_timeout, + persistent_workers=persistent_workers, + ) + + # Video Dataset + logger.info("Loading Video dataset") + dataset_v2t = None + dataset_v2s = None + train_dataloader_v2t = None + train_dataloader_v2s = None + sampler_v2t = None + sampler_v2s = None + speech_cfg = getattr(dataset_config, "video_speech_dataset", {}) + if not isinstance(speech_cfg, dict): + speech_cfg = OmegaConf.to_container(speech_cfg, resolve=True) + speech_cfg = speech_cfg or {} + + if batch_size_v2t_cfg > 0: + dataset_v2t = VideoCaptionDataset( + transform=image_transform, + tokenizer=uni_prompting.text_tokenizer, + max_seq_length=preproc_config.max_seq_length, + resolution=preproc_config.resolution, + sample_method="uniform", + dataset_name=speech_cfg.get("llavavid_dataset_name", "llavavid"), + llavavid_path=speech_cfg.get("llavavid_path", "lmms-lab/LLaVA-Video-178K"), + num_frames=8, + llavavid_local_files_only=bool(speech_cfg.get("llavavid_local_files_only", False)), + llavavid_skip_configs=speech_cfg.get("llavavid_skip_configs"), + llavavid_skip_video_patterns=speech_cfg.get("llavavid_skip_video_patterns"), + ) + + sampler_v2t = build_distributed_sampler( + dataset_v2t, + shuffle=True, + drop_last=True, + ) + + train_dataloader_v2t = DataLoader( + dataset_v2t, + batch_size=batch_size_v2t_cfg, + num_workers=dataset_config.num_workers, + collate_fn=collate_fn_v2t, + sampler=sampler_v2t, + shuffle=sampler_v2t is None, + drop_last=True, + pin_memory=pin_memory, + timeout=dataloader_timeout, + persistent_workers=persistent_workers, + ) + + if batch_size_v2s_cfg > 0: + dataset_v2s = VideoSpeechDataset( + transform=image_transform, + resolution=preproc_config.resolution, + num_frames=speech_cfg.get("num_frames_speech", 4), + video_root=speech_cfg.get( + "video_root", "/home/work/AIDAS/data/video/openvid1m/video/video" + ), + audio_root=speech_cfg.get( + "audio_root", "/home/work/AIDAS/data/video-speech" + ), + speech_dir_name=speech_cfg.get("speech_dir_name", "openvid-speech-trunc"), + index_path=speech_cfg.get( + "index_path", "/home/work/AIDAS/data/video-speech/openvid-speech.csv" + ), + sample_method=speech_cfg.get("sample_method", "uniform"), + precomputed_tokens_root=( + speech_cfg.get("precomputed_tokens_root") + if speech_cfg.get("use_precomputed_tokens", False) + else None + ), + ) + + sampler_v2s = build_distributed_sampler( + dataset_v2s, + shuffle=True, + drop_last=True, + ) + + train_dataloader_v2s = DataLoader( + dataset_v2s, + batch_size=batch_size_v2s_cfg, + num_workers=dataset_config.num_workers, + collate_fn=collate_fn_v2s, + sampler=sampler_v2s, + shuffle=sampler_v2s is None, + drop_last=True, + pin_memory=pin_memory, + timeout=dataloader_timeout, + persistent_workers=persistent_workers, + ) + + # Speech Dataset + logger.info("Loading Speech dataset") + dataset_sm = MixedSpeechTextDataset(config.dataset.params.audio_data) + + # Speech-to-Speech Dataset (EMOVA + Instruct S2S) + dataset_s2s = None + sampler_s2s = None + train_dataloader_s2s = None + if config.training.batch_size_s2s > 0: + dataset_s2s = Speech2SpeechDataset(dataset_config.get("speech2speech", [])) + + # Multi-image interleaved dataset (MMU-style) + logger.info("Loading MMU dataset") + dataset_mmu = None + sampler_mmu = None + train_dataloader_mmu = None + if config.training.batch_size_mmu > 0: + mmu_params = dataset_config.get("mmu_interleaved", {}) + if mmu_params is None: + mmu_kwargs = {} + elif isinstance(mmu_params, dict): + mmu_kwargs = mmu_params + else: + mmu_kwargs = OmegaConf.to_container(mmu_params, resolve=True) + dataset_mmu = TextImageInterleavedDataset(**mmu_kwargs) + + logger.info("Dataset Prepared.") + + # Use distinct DistributedSamplers for each speech dataloader to avoid iterator interference + if accelerator.num_processes > 1: + sampler_s2t = DistributedSampler( + dataset_sm, + num_replicas=accelerator.num_processes, + rank=accelerator.process_index, + shuffle=True, + drop_last=True, + ) + sampler_t2s = DistributedSampler( + dataset_sm, + num_replicas=accelerator.num_processes, + rank=accelerator.process_index, + shuffle=True, + drop_last=True, + ) + if dataset_s2s is not None: + sampler_s2s = DistributedSampler( + dataset_s2s, + num_replicas=accelerator.num_processes, + rank=accelerator.process_index, + shuffle=True, + drop_last=True, + ) + if dataset_mmu is not None: + sampler_mmu = DistributedSampler( + dataset_mmu, + num_replicas=accelerator.num_processes, + rank=accelerator.process_index, + shuffle=True, + drop_last=True, + ) + else: + sampler_s2t = None + sampler_t2s = None + sampler_s2s = None + sampler_mmu = None + + train_dataloader_s2t = DataLoader( + dataset_sm, + batch_size=config.training.batch_size_s2t, + shuffle=False, + sampler=sampler_s2t, + collate_fn=collate_fn_audio, + num_workers=config.dataset.params.num_workers, + drop_last=True, + pin_memory=pin_memory, + timeout=dataloader_timeout, + persistent_workers=persistent_workers, + ) + train_dataloader_t2s = DataLoader( + dataset_sm, + batch_size=config.training.batch_size_t2s, + shuffle=False, + sampler=sampler_t2s, + collate_fn=collate_fn_audio, + num_workers=config.dataset.params.num_workers, + drop_last=True, + pin_memory=pin_memory, + timeout=dataloader_timeout, + persistent_workers=persistent_workers, + ) + + if dataset_s2s is not None: + train_dataloader_s2s = DataLoader( + dataset_s2s, + batch_size=batch_size_s2s_cfg, + shuffle=False, + sampler=sampler_s2s, + collate_fn=s2s_collate_fn, + num_workers=config.dataset.params.num_workers, + drop_last=True, + pin_memory=pin_memory, + timeout=dataloader_timeout, + persistent_workers=persistent_workers, + ) + + if dataset_mmu is not None: + train_dataloader_mmu = DataLoader( + dataset_mmu, + batch_size=config.training.batch_size_mmu, + shuffle=False, + sampler=sampler_mmu, + collate_fn=collate_fn_mmu_mult, + num_workers=config.dataset.params.num_workers, + drop_last=True, + pin_memory=pin_memory, + timeout=dataloader_timeout, + persistent_workers=persistent_workers, + ) + + def _num_steps(dataset_obj, batch_size_cfg): + if dataset_obj is None or batch_size_cfg <= 0: + return 0 + total_bs = batch_size_cfg * accelerator.num_processes * config.training.gradient_accumulation_steps + if total_bs <= 0: + return 0 + length = len(dataset_obj) + if length == 0: + return 0 + return math.ceil(length / total_bs) + + num_update_steps_per_epoch_t2i = _num_steps(dataset_t2i, config.training.batch_size_t2i) + num_update_steps_per_epoch_i2i = _num_steps(dataset_i2i, config.training.batch_size_t2i) + num_update_steps_per_epoch_lm = _num_steps(dataset_lm, config.training.batch_size_lm) + num_update_steps_per_epoch_s2t = _num_steps(dataset_sm, config.training.batch_size_s2t) + num_update_steps_per_epoch_t2s = _num_steps(dataset_sm, config.training.batch_size_t2s) + num_update_steps_per_epoch_s2s = _num_steps(dataset_s2s, batch_size_s2s_cfg) + num_update_steps_per_epoch_v2t = _num_steps(dataset_v2t, batch_size_v2t_cfg) + num_update_steps_per_epoch_v2s = _num_steps(dataset_v2s, batch_size_v2s_cfg) + num_update_steps_per_epoch_mmu = _num_steps(dataset_mmu, config.training.batch_size_mmu) + + # Calculate num_train_epochs + num_update_steps_per_epoch = max( + num_update_steps_per_epoch_t2i, + num_update_steps_per_epoch_lm, + num_update_steps_per_epoch_s2t, + num_update_steps_per_epoch_t2s, + num_update_steps_per_epoch_v2t, + num_update_steps_per_epoch_v2s, + num_update_steps_per_epoch_s2s, + num_update_steps_per_epoch_mmu, + num_update_steps_per_epoch_i2i, + ) + + num_train_epochs = math.ceil(config.training.max_train_steps / num_update_steps_per_epoch) if num_update_steps_per_epoch > 0 else 1 + + logger.info(f"len of T2I: {len(dataset_t2i) if dataset_t2i is not None else 0}") + logger.info(f"len of I2I: {len(dataset_i2i) if dataset_i2i is not None else 0}") + logger.info(f"len of LM: {len(dataset_lm)}") + logger.info(f"len of Speech: {len(dataset_sm)}") + logger.info(f"len of Video Caption: {len(dataset_v2t) if dataset_v2t is not None else 0}") + logger.info(f"len of Video Speech: {len(dataset_v2s) if dataset_v2s is not None else 0}") + logger.info(f"len of S2S: {len(dataset_s2s)}") + logger.info(f"len of MMU: {len(dataset_mmu)}") + + logger.info(f"Train stpes: {config.training.max_train_steps}") + logger.info(f"Num train epochs: {num_train_epochs}") + + ################################## + # MODEL RESUME # + ################################# + global_step = 0 + first_epoch = 0 + start_step = 0 + + if config.experiment.resume_from_checkpoint: + dirs = os.listdir(config.experiment.output_dir) + logger.info(f"dirs: {dirs}") + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + logger.info(f"path: {path}") + if path is not None: + path = os.path.join(config.experiment.output_dir, path) + logger.info(f"Resuming from checkpoint: {path}") + global_step = start_step = int(os.path.basename(path).split("-")[1]) + first_epoch = global_step // num_update_steps_per_epoch + if os.path.exists(f'{path}/unwrapped_model/pytorch_model.bin'): + state_dict = torch.load(f'{path}/unwrapped_model/pytorch_model.bin', map_location="cpu") + model.load_state_dict(state_dict, strict=True) + del state_dict + elif os.path.exists(f'{path}/unwrapped_model/pytorch_model.bin.index.json'): + from safetensors.torch import load_file + from transformers.modeling_utils import load_sharded_checkpoint + load_sharded_checkpoint(model, f'{path}/unwrapped_model/') + # if safetensors sharded checkpoint exists + elif os.path.exists(f'{path}/unwrapped_model/model.safetensors.index.json'): + from transformers.modeling_utils import load_sharded_checkpoint + load_sharded_checkpoint( + model, + f'{path}/unwrapped_model/', + ) + else: + raise FileNotFoundError(f"Checkpoint {path}/unwrapped_model/pytorch_model.bin or safetensors not found") + else: + logger.info("Not resuming from checkpoint") + + ################################## + # Prepare accelerator # + ################################# + logger.info("Preparing model, optimizer and dataloaders") + + lr_scheduler = get_scheduler( + config.lr_scheduler.scheduler, + optimizer=optimizer, + num_training_steps=config.training.max_train_steps, + num_warmup_steps=config.lr_scheduler.params.warmup_steps, + min_lr_scale=config.lr_scheduler.params.min_lr_scale + ) + + # model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler) + model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler) + + lr_scheduler = get_scheduler( + config.lr_scheduler.scheduler, + optimizer=optimizer, + num_training_steps=config.training.max_train_steps, + num_warmup_steps=config.lr_scheduler.params.warmup_steps, + min_lr_scale=config.lr_scheduler.params.min_lr_scale + ) + + vq_model_image.to(device=accelerator.device) + vq_model_audio.to(device=accelerator.device) + + mask_dtype = model.get_input_embeddings().weight.dtype + + def _log_and_flag_failure(message: str, exc: Exception = None): + """Log preprocessing failures on both logger and accelerator console.""" + if exc is not None: + logger.exception(message) + else: + logger.error(message) + accelerator.print(message) + + def _maybe_trim_audio_file(audio_path: Union[str, os.PathLike], max_duration: float) -> tuple[Union[str, os.PathLike], Optional[str]]: + """Return a path to an audio file trimmed to max_duration seconds. + + If trimming succeeds, returns (trimmed_path, temp_path) where trimmed_path is the + file to use for encoding and temp_path should be deleted afterwards. If trimming + fails, returns (audio_path, None). + """ + if max_duration <= 0: + return audio_path, None + trim_timeout = float(getattr(config.dataset.preprocessing, "audio_trim_timeout_sec", 30.0)) + try: + tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False) + tmp_path = tmp.name + tmp.close() + cmd = [ + "ffmpeg", + "-y", + "-hide_banner", + "-loglevel", + "error", + "-i", + str(audio_path), + "-t", + str(max_duration), + "-c", + "copy", + tmp_path, + ] + subprocess.run(cmd, check=True, timeout=trim_timeout) + return tmp_path, tmp_path + except Exception as exc: + warnings.warn(f"Failed to trim audio {audio_path} to {max_duration}s: {exc}") + try: + if 'tmp_path' in locals() and os.path.exists(tmp_path): + os.remove(tmp_path) + except OSError: + pass + return audio_path, None + + def _format_path_for_log(path: Union[str, os.PathLike, torch.Tensor, None]) -> str: + if isinstance(path, (str, os.PathLike)): + try: + return os.fspath(path) + except TypeError: + return str(path) + if isinstance(path, torch.Tensor): + return f"" + if isinstance(path, np.ndarray): + return f"" + if isinstance(path, Sequence) and not isinstance(path, (str, bytes, os.PathLike)): + try: + return f"" + except Exception: + return "" + return repr(path) + + def safe_audio_encode(audio_path: Union[str, torch.Tensor, np.ndarray, Sequence[int]], flow_name: str): + if isinstance(audio_path, torch.Tensor): + return audio_path.cpu().clone(), None + if isinstance(audio_path, np.ndarray): + try: + tensor = torch.from_numpy(audio_path).to(dtype=torch.long) + except Exception as exc: + raise RuntimeError(f"Failed to convert numpy audio tokens to tensor for flow '{flow_name}': {exc}") from exc + return tensor, None + if isinstance(audio_path, Sequence) and not isinstance(audio_path, (str, bytes, os.PathLike)): + try: + tensor = torch.as_tensor(audio_path, dtype=torch.long) + except Exception as exc: + raise RuntimeError(f"Failed to convert cached audio tokens to tensor for flow '{flow_name}': {exc}") from exc + return tensor, None + path_repr = _format_path_for_log(audio_path) + if logger.isEnabledFor(logging.DEBUG): + logger.debug( + "[rank %s] (%s) audio encode request: %s", + accelerator.process_index, + flow_name, + path_repr, + ) + + max_retries = int(getattr(config.dataset.preprocessing, "audio_encode_max_retries", 3)) + backoff = float(getattr(config.dataset.preprocessing, "audio_encode_retry_backoff_sec", 0.5)) + duration_limit = float(getattr(config.dataset.preprocessing, "max_audio_duration_sec", 15.0)) + + cached = _load_cached_audio_tokens(audio_path) + if cached is not None: + if logger.isEnabledFor(logging.DEBUG): + logger.debug( + "[rank %s] (%s) audio encode hit cache: %s", + accelerator.process_index, + flow_name, + path_repr, + ) + return cached, None + + for attempt in range(1, max_retries + 1): + trimmed_path: Union[str, os.PathLike] = audio_path + temp_path: Optional[str] = None + try: + if isinstance(audio_path, (str, os.PathLike)): + trimmed_path, temp_path = _maybe_trim_audio_file(audio_path, duration_limit) + if logger.isEnabledFor(logging.DEBUG): + logger.debug( + "[rank %s] (%s) audio encode attempt %d/%d (trimmed=%s): %s", + accelerator.process_index, + flow_name, + attempt, + max_retries, + "yes" if temp_path is not None else "no", + _format_path_for_log(trimmed_path), + ) + tokens = vq_model_audio.encode(str(trimmed_path)).cpu() + _store_cached_audio_tokens(audio_path, tokens) + if logger.isEnabledFor(logging.DEBUG): + logger.debug( + "[rank %s] (%s) audio encode success: %s", + accelerator.process_index, + flow_name, + path_repr, + ) + return tokens, None + except Exception as exc: + if attempt == max_retries: + msg = ( + f"[Rank {accelerator.process_index}] {flow_name} audio encode failed " + f"for '{audio_path}': {exc}" + ) + _log_and_flag_failure(msg, exc) + return None, msg + sleep_time = min(backoff * attempt, 2.0) + time.sleep(sleep_time) + finally: + if temp_path is not None and os.path.exists(temp_path): + try: + os.remove(temp_path) + except OSError: + pass + + def safe_video_get_code(video_tensor_sample: torch.Tensor, sample_index: int): + max_retries = int(getattr(config.dataset.preprocessing, "video_encode_max_retries", 3)) + backoff = float(getattr(config.dataset.preprocessing, "video_encode_retry_backoff_sec", 0.5)) + for attempt in range(1, max_retries + 1): + try: + if logger.isEnabledFor(logging.DEBUG): + logger.debug( + "[rank %s] video encode request sample=%d attempt=%d/%d", + accelerator.process_index, + sample_index, + attempt, + max_retries, + ) + video_token = vq_model_image.get_code(video_tensor_sample) + if logger.isEnabledFor(logging.DEBUG): + logger.debug( + "[rank %s] video encode success sample=%d", + accelerator.process_index, + sample_index, + ) + return video_token, None + except Exception as exc: + if attempt == max_retries: + msg = ( + f"[Rank {accelerator.process_index}] v2t video encode failed " + f"for sample index {sample_index}: {exc}" + ) + _log_and_flag_failure(msg, exc) + return None, msg + logger.warning( + "[rank %s] video encode retry sample=%d attempt=%d/%d error=%s", + accelerator.process_index, + sample_index, + attempt, + max_retries, + exc, + ) + sleep_time = min(backoff * attempt, 2.0) + time.sleep(sleep_time) + + def safe_image_get_code(image_tensor_sample: torch.Tensor, sample_index: int): + max_retries = int(getattr(config.dataset.preprocessing, "image_encode_max_retries", 3)) + backoff = float(getattr(config.dataset.preprocessing, "image_encode_retry_backoff_sec", 0.5)) + for attempt in range(1, max_retries + 1): + try: + if logger.isEnabledFor(logging.DEBUG): + logger.debug( + "[rank %s] image encode request sample=%d attempt=%d/%d", + accelerator.process_index, + sample_index, + attempt, + max_retries, + ) + if image_tensor_sample.dim() == 3: + image_tensor_sample = image_tensor_sample.unsqueeze(0) + elif image_tensor_sample.dim() != 4: + raise ValueError( + f"Expected image tensor with 3 or 4 dims, got shape {tuple(image_tensor_sample.shape)}" + ) + image_token = vq_model_image.get_code(image_tensor_sample) + if logger.isEnabledFor(logging.DEBUG): + logger.debug( + "[rank %s] image encode success sample=%d", + accelerator.process_index, + sample_index, + ) + return image_token, None + except Exception as exc: + if attempt == max_retries: + msg = ( + f"[Rank {accelerator.process_index}] s2s image encode failed " + f"for sample index {sample_index}: {exc}" + ) + _log_and_flag_failure(msg, exc) + return None, msg + logger.warning( + "[rank %s] image encode retry sample=%d attempt=%d/%d error=%s", + accelerator.process_index, + sample_index, + attempt, + max_retries, + exc, + ) + sleep_time = min(backoff * attempt, 2.0) + time.sleep(sleep_time) + + def _decode_single_image(single_like): + if single_like is None: + return None + if isinstance(single_like, Image.Image): + return single_like.convert('RGB') + + data_bytes = None + + if isinstance(single_like, (bytes, bytearray)): + data_bytes = bytes(single_like) + elif isinstance(single_like, str): + try: + data_bytes = base64.b64decode(single_like) + except (binascii.Error, ValueError): + if os.path.isfile(single_like): + try: + with open(single_like, 'rb') as fh: + data_bytes = fh.read() + except OSError: + data_bytes = None + elif isinstance(single_like, dict): + binary_payload = single_like.get('bytes') + if binary_payload is not None: + data_bytes = binary_payload + else: + path_value = single_like.get('path') + if path_value and os.path.isfile(path_value): + try: + with open(path_value, 'rb') as fh: + data_bytes = fh.read() + except OSError: + data_bytes = None + + if data_bytes is None: + return None + + try: + with Image.open(BytesIO(data_bytes)) as img: + return img.convert('RGB') + except Exception: + return None + + def maybe_decode_image(image_like): + if isinstance(image_like, (list, tuple)): + return [_decode_single_image(item) for item in image_like] + return _decode_single_image(image_like) + + ################################## + # Training # + ################################# + logger.info("***** Running training *****") + logger.info(f" Num training steps = {config.training.max_train_steps}") + logger.info(f" Instantaneous batch size per device = {total_batch_size_per_gpu}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {config.training.gradient_accumulation_steps}") + + @torch.no_grad() + def prepare_inputs_and_labels( + pixel_values_or_image_ids: Union[torch.FloatTensor, torch.LongTensor], + texts: Union[str, str], + min_masking_rate: float = 0.0, + is_train: bool = True, + seed: int = None + ): + + image_tokens = vq_model_image.get_code(pixel_values_or_image_ids) + image_tokens = image_tokens + len(uni_prompting.text_tokenizer) + # create MLM mask and labels + input_ids, labels, loss_weight, mask_prob = mask_or_random_replace_tokens( + image_tokens, + mask_id, + config, + mask_schedule=mask_schedule, + is_train=is_train, + ) + input_ids, masks, labels = uni_prompting((texts, input_ids, labels), 't2i') + return input_ids, labels, mask_prob, image_tokens, masks + + @torch.no_grad() + def prepare_inputs_and_labels_for_i2i( + source_images: torch.FloatTensor, + target_images: torch.FloatTensor, + prompts: list[str], + is_train: bool = True, + ): + """Build masked i2i sequences from source/target image pairs.""" + + # Tokenize source/target images with VQ model and offset by text vocab size + source_tokens = vq_model_image.get_code(source_images) + len(uni_prompting.text_tokenizer) + target_tokens = vq_model_image.get_code(target_images) + len(uni_prompting.text_tokenizer) + + cond_dropout_prob = config.training.get( + "i2i_cond_dropout_prob", + config.training.cond_dropout_prob, + ) + + if is_train and torch.rand(1, device=source_tokens.device).item() < cond_dropout_prob: + effective_prompts = [''] * len(prompts) + masked_target_source = source_tokens + else: + effective_prompts = list(prompts) + masked_target_source = target_tokens + + masked_target_tokens, labels, _, mask_prob = mask_or_random_replace_tokens( + masked_target_source, + mask_id, + config, + mask_schedule=mask_schedule, + is_train=is_train, + ) + + input_ids, attention_masks, labels = uni_prompting( + (effective_prompts, source_tokens, masked_target_tokens, labels), + 'i2i' + ) + + return input_ids, labels, mask_prob, attention_masks + + @torch.no_grad() + def prepare_inputs_and_labels_for_text( + texts: Union[str, str], max_seq_len, eps=1e-3 + ): + # create MLM mask and labels + + input_ids_lm, prompt_mask, labels_lm = uni_prompting((texts, max_seq_len), 'lm') + b, l = input_ids_lm.shape + t = torch.rand(b, device=input_ids_lm.device) + p_mask = (1 - eps) * t + eps + p_mask = p_mask[:, None].repeat(1, l) + + masked_indices = torch.rand((b, l), device=input_ids_lm.device) < p_mask + # 126336 is used for [MASK] token + noisy_batch = torch.where(masked_indices, mask_id, input_ids_lm) + masked_indices = noisy_batch == mask_id + + return noisy_batch, labels_lm, p_mask + + # Video also uses this. + @torch.no_grad() + def prepare_inputs_and_labels_for_mmu( + input_ids_mmu, prompt_masks, labels_mmu, eps=1e-3 + ): + b, l = input_ids_mmu.shape + t = torch.rand(b, device=input_ids_mmu.device) + p_mask = (1 - eps) * t + eps + p_mask = p_mask[:, None].repeat(1, l) + + masked_indices = torch.rand((b, l), device=input_ids_mmu.device) < p_mask + # 126336 is used for [MASK] token + noisy_batch = torch.where(masked_indices, mask_id, input_ids_mmu) + masked_indices = noisy_batch == mask_id + noisy_batch[prompt_masks.bool()] = input_ids_mmu[prompt_masks.bool()] + masked_indices = noisy_batch == mask_id + + prompt_masks = prompt_masks.to(torch.int64) + answer_lengths = torch.sum((1 - prompt_masks), dim=-1, keepdim=True) + answer_lengths = answer_lengths.repeat(1, noisy_batch.shape[1]) + + return noisy_batch, labels_mmu, p_mask, answer_lengths + + @torch.no_grad() + def prepare_inputs_and_labels_for_t2s( + input_ids_t2s, prompt_masks, labels_t2s, eps=1e-3 + ): + b, l = input_ids_t2s.shape + t = torch.rand(b, device=input_ids_t2s.device) + p_mask = (1 - eps) * t + eps + p_mask = p_mask[:, None].repeat(1, l) + + masked_indices = torch.rand((b, l), device=input_ids_t2s.device) < p_mask + noisy_batch = torch.where(masked_indices, mask_id, input_ids_t2s) + masked_indices = noisy_batch == mask_id + + noisy_batch[prompt_masks.bool()] = input_ids_t2s[prompt_masks.bool()] + masked_indices = noisy_batch == mask_id + + prompt_masks = prompt_masks.to(torch.int64) + answer_lengths = torch.sum((1 - prompt_masks), dim=-1, keepdim=True) + answer_lengths = answer_lengths.repeat(1, noisy_batch.shape[1]) + + return noisy_batch, labels_t2s, p_mask, answer_lengths + + @torch.no_grad() + def prepare_inputs_and_labels_for_v2s( + input_ids_v2s: torch.Tensor, + prompt_masks: torch.Tensor, + labels_v2s: torch.Tensor, + eps: float = 1e-3, + ): + b, l = input_ids_v2s.shape + device_local = input_ids_v2s.device + + p_mask = eps + (1.0 - eps) * torch.rand(b, device=device_local) + p_mask = p_mask.unsqueeze(1).expand(b, l) + + rand_mat = torch.rand((b, l), device=device_local) + answer_region = (prompt_masks == 0) + masked_indices = (rand_mat < p_mask) & answer_region + + noisy_batch = input_ids_v2s.clone() + noisy_batch[masked_indices] = mask_id + + answer_lengths = answer_region.sum(dim=-1, keepdim=True) + expanded_lengths = answer_lengths.expand(b, l) + + return noisy_batch.long(), labels_v2s.long(), p_mask, expanded_lengths.long() + + + @torch.no_grad() + def prepare_inputs_and_labels_for_s2s( + input_ids_s2s, prompt_masks, labels_s2s, eps=1e-3 + ): + b, l = input_ids_s2s.shape + t = torch.rand(b, device=input_ids_s2s.device) + p_mask = (1 - eps) * t + eps + p_mask = p_mask[:, None].repeat(1, l) + + masked_indices = torch.rand((b, l), device=input_ids_s2s.device) < p_mask + noisy_batch = torch.where(masked_indices, mask_id, input_ids_s2s) + masked_indices = noisy_batch == mask_id + + noisy_batch[prompt_masks.bool()] = input_ids_s2s[prompt_masks.bool()] + masked_indices = noisy_batch == mask_id + + prompt_masks = prompt_masks.to(torch.int64) + answer_lengths = torch.sum((1 - prompt_masks), dim=-1, keepdim=True) + answer_lengths = answer_lengths.repeat(1, noisy_batch.shape[1]) + + return noisy_batch, labels_s2s, p_mask, answer_lengths + + + @torch.no_grad() + def prepare_inputs_and_labels_for_s2t( + input_ids_mmu, prompt_masks, labels_mmu, eps=1e-3 + ): + b, l = input_ids_mmu.shape + t = torch.rand(b, device=input_ids_mmu.device) + p_mask = (1 - eps) * t + eps + p_mask = p_mask[:, None].repeat(1, l) + + masked_indices = torch.rand((b, l), device=input_ids_mmu.device) < p_mask + # 126336 is used for [MASK] token + noisy_batch = torch.where(masked_indices, mask_id, input_ids_mmu) + masked_indices = noisy_batch == mask_id + noisy_batch[prompt_masks.bool()] = input_ids_mmu[prompt_masks.bool()] + masked_indices = noisy_batch == mask_id + + prompt_masks = prompt_masks.to(torch.int64) + answer_lengths = torch.sum((1 - prompt_masks), dim=-1, keepdim=True) + answer_lengths = answer_lengths.repeat(1, noisy_batch.shape[1]) + + return noisy_batch, labels_mmu, p_mask, answer_lengths + + batch_time_m = AverageMeter() + data_time_m = AverageMeter() + end = time.time() + + v2t_iterator: Optional[Iterator] = None + v2s_iterator: Optional[Iterator] = None + t2i_iterator: Optional[Iterator] = None + i2i_iterator: Optional[Iterator] = None + + def _next_from_v2t(): + nonlocal v2t_iterator + if train_dataloader_v2t is None: + return None + try: + return next(v2t_iterator) + except StopIteration: + v2t_iterator = iter(train_dataloader_v2t) + return next(v2t_iterator) + + def _next_from_v2s(): + nonlocal v2s_iterator + if train_dataloader_v2s is None: + return None + try: + return next(v2s_iterator) + except StopIteration: + v2s_iterator = iter(train_dataloader_v2s) + return next(v2s_iterator) + + def _next_from_t2i(): + nonlocal t2i_iterator + if train_dataloader_t2i is None: + return None + try: + return next(t2i_iterator) + except StopIteration: + t2i_iterator = iter(train_dataloader_t2i) + return next(t2i_iterator) + + def _next_from_i2i(): + nonlocal i2i_iterator + if train_dataloader_i2i is None: + return None + try: + return next(i2i_iterator) + except StopIteration: + i2i_iterator = iter(train_dataloader_i2i) + return next(i2i_iterator) + + def _next_from_s2t(): + nonlocal s2t_iterator + if train_dataloader_s2t is None: + return None + try: + return next(s2t_iterator) + except StopIteration: + s2t_iterator = iter(train_dataloader_s2t) + return next(s2t_iterator) + + def _next_from_t2s(): + nonlocal t2s_iterator + if train_dataloader_t2s is None: + return None + try: + return next(t2s_iterator) + except StopIteration: + t2s_iterator = iter(train_dataloader_t2s) + return next(t2s_iterator) + + def _next_from_lm(): + nonlocal lm_iterator + if train_dataloader_lm is None: + return None + try: + return next(lm_iterator) + except StopIteration: + lm_iterator = iter(train_dataloader_lm) + return next(lm_iterator) + + def _next_from_mmu(): + nonlocal mmu_iterator + if train_dataloader_mmu is None: + return None + try: + return next(mmu_iterator) + except StopIteration: + mmu_iterator = iter(train_dataloader_mmu) + return next(mmu_iterator) + + def _next_from_s2s(): + nonlocal s2s_iterator + if train_dataloader_s2s is None: + return None + try: + return next(s2s_iterator) + except StopIteration: + s2s_iterator = iter(train_dataloader_s2s) + return next(s2s_iterator) + + v2t_iterator = iter(train_dataloader_v2t) if train_dataloader_v2t is not None else None + v2s_iterator = iter(train_dataloader_v2s) if train_dataloader_v2s is not None else None + t2i_iterator = iter(train_dataloader_t2i) if train_dataloader_t2i is not None else None + i2i_iterator = iter(train_dataloader_i2i) if train_dataloader_i2i is not None else None + s2t_iterator = iter(train_dataloader_s2t) if train_dataloader_s2t is not None else None + t2s_iterator = iter(train_dataloader_t2s) if train_dataloader_t2s is not None else None + lm_iterator = iter(train_dataloader_lm) if train_dataloader_lm is not None else None + mmu_iterator = iter(train_dataloader_mmu) if train_dataloader_mmu is not None else None + s2s_iterator = iter(train_dataloader_s2s) if train_dataloader_s2s is not None else None + + for epoch in tqdm(range(first_epoch, num_train_epochs), desc="Epochs", disable=not accelerator.is_main_process, position=0): + # Ensure all samplers reshuffle in a rank-consistent way each epoch + try: + if isinstance(sampler_t2i, DistributedSampler): + sampler_t2i.set_epoch(epoch) + if isinstance(sampler_i2i, DistributedSampler): + sampler_i2i.set_epoch(epoch) + if isinstance(sampler_v2t, DistributedSampler): + sampler_v2t.set_epoch(epoch) + if isinstance(sampler_v2s, DistributedSampler): + sampler_v2s.set_epoch(epoch) + if accelerator.num_processes > 1: + if sampler_s2t is not None: + sampler_s2t.set_epoch(epoch) + if sampler_t2s is not None: + sampler_t2s.set_epoch(epoch) + if sampler_s2s is not None: + sampler_s2s.set_epoch(epoch) + except Exception: + pass + model.train() + for step_in_epoch in range(num_update_steps_per_epoch): + skip_local = 0 + batch: dict[str, Any] = {} + + lm_batch = _next_from_lm() + if train_dataloader_lm is not None: + if lm_batch is None: + skip_local = 1 + batch["lm_flow"] = lm_batch + + mmu_batch = _next_from_mmu() + if train_dataloader_mmu is not None: + if mmu_batch is None: + skip_local = 1 + batch["mmu_flow"] = mmu_batch + + s2s_batch = _next_from_s2s() + if train_dataloader_s2s is not None: + if s2s_batch is None: + skip_local = 1 + batch["s2s_flow"] = s2s_batch + + skip_tensor = torch.tensor(skip_local, device=accelerator.device, dtype=torch.int32) + skip_sum = accelerator.reduce(skip_tensor, reduction="sum") + if skip_sum.item() > 0: + if accelerator.is_main_process and skip_local: + logger.warning( + "Skipping global step %s due to empty base batch (lm/mmu/s2s).", + global_step, + ) + batch_time_m.reset() + data_time_m.reset() + end = time.time() + continue + + v2t_batch = None + v2s_batch = None + selected_v2_branch: Optional[str] = None + t2i_batch = None + i2i_batch = None + selected_x2i_branch: Optional[str] = None + + v2_choices: list[str] = [] + if train_dataloader_v2t is not None: + v2_choices.append("v2t") + if train_dataloader_v2s is not None: + v2_choices.append("v2s") + + if v2_choices: + if accelerator.num_processes > 1: + local_choice = random.randrange(len(v2_choices)) if accelerator.is_main_process else 0 + choice_idx = _broadcast_choice(local_choice, accelerator) + else: + choice_idx = random.randrange(len(v2_choices)) + + if choice_idx < 0 or choice_idx >= len(v2_choices): + if accelerator.is_main_process: + logger.warning( + "Received out-of-range v2 branch index %s for %s choices; clamping.", + choice_idx, + len(v2_choices), + ) + choice_idx = choice_idx % len(v2_choices) + + selected_v2_branch = v2_choices[choice_idx] + if selected_v2_branch == "v2t": + v2t_batch = _next_from_v2t() + else: + v2s_batch = _next_from_v2s() + + batch["v2t_flow"] = v2t_batch + batch["v2s_flow"] = v2s_batch + + # Initialize speech flows with empty placeholders; they will be populated if selected. + batch["s2t_flow"] = _empty_audio_batch() + batch["t2s_flow"] = _empty_audio_batch() + + speech_choices: list[str] = [] + if train_dataloader_s2t is not None: + speech_choices.append("s2t") + if train_dataloader_t2s is not None: + speech_choices.append("t2s") + + selected_speech_branch: Optional[str] = None + if speech_choices: + if accelerator.num_processes > 1: + local_choice = random.randrange(len(speech_choices)) if accelerator.is_main_process else 0 + choice_idx = _broadcast_choice(local_choice, accelerator) + else: + choice_idx = random.randrange(len(speech_choices)) + + if choice_idx < 0 or choice_idx >= len(speech_choices): + if accelerator.is_main_process: + logger.warning( + "Received out-of-range speech branch index %s for %s choices; clamping.", + choice_idx, + len(speech_choices), + ) + choice_idx = choice_idx % len(speech_choices) + + selected_speech_branch = speech_choices[choice_idx] + if selected_speech_branch == "s2t": + speech_batch = _next_from_s2t() + if speech_batch is None: + skip_local = 1 + else: + batch["s2t_flow"] = speech_batch + else: + speech_batch = _next_from_t2s() + if speech_batch is None: + skip_local = 1 + else: + batch["t2s_flow"] = speech_batch + + task_payloads: List[TaskBatchPayload] = [] + + x2i_choices: list[str] = [] + if train_dataloader_t2i is not None: + x2i_choices.append("t2i") + if train_dataloader_i2i is not None: + x2i_choices.append("i2i") + + if x2i_choices: + if accelerator.num_processes > 1: + local_choice = random.randrange(len(x2i_choices)) if accelerator.is_main_process else 0 + choice_idx = _broadcast_choice(local_choice, accelerator) + else: + choice_idx = random.randrange(len(x2i_choices)) + + if choice_idx < 0 or choice_idx >= len(x2i_choices): + if accelerator.is_main_process: + logger.warning( + "Received out-of-range x2i branch index %s for %s choices; clamping.", + choice_idx, + len(x2i_choices), + ) + choice_idx = choice_idx % len(x2i_choices) + + selected_x2i_branch = x2i_choices[choice_idx] + if selected_x2i_branch == "t2i": + t2i_batch = _next_from_t2i() + else: + i2i_batch = _next_from_i2i() + + # Synchronize skip decision across all ranks to avoid collective mismatches + required_flows = ["t2s_flow", "s2t_flow"] + if train_dataloader_lm is not None: + required_flows.append("lm_flow") + if train_dataloader_mmu is not None: + required_flows.append("mmu_flow") + if train_dataloader_s2s is not None: + required_flows.append("s2s_flow") + + local_skip = 0 + if selected_v2_branch == "v2t" and v2t_batch is None: + local_skip = 1 + elif selected_v2_branch == "v2s" and v2s_batch is None: + local_skip = 1 + else: + for key in required_flows: + if batch.get(key) is None: + local_skip = 1 + break + if selected_x2i_branch == "t2i": + if t2i_batch is None: + local_skip = 1 + else: + t2i_images = t2i_batch["t2i"].get("images") + if not isinstance(t2i_images, torch.Tensor) or t2i_images.shape[0] == 0: + local_skip = 1 + elif selected_x2i_branch == "i2i": + if i2i_batch is None: + local_skip = 1 + else: + i2i_sources = i2i_batch["i2i"].get("source_images") + i2i_targets = i2i_batch["i2i"].get("target_images") + if ( + not isinstance(i2i_sources, torch.Tensor) + or not isinstance(i2i_targets, torch.Tensor) + or i2i_sources.shape[0] == 0 + or i2i_targets.shape[0] == 0 + ): + local_skip = 1 + try: + skip_tensor = torch.tensor(local_skip, device=accelerator.device, dtype=torch.int32) + skip_sum = accelerator.reduce(skip_tensor, reduction='sum') + should_skip = skip_sum.item() > 0 + except Exception: + # Fallback if reduce isn't available for any reason + should_skip = local_skip == 1 + + if should_skip: + if accelerator.is_main_process and local_skip: + logger.warning(f"Skipping step {global_step} (required multimodal batch missing) [synced]") + continue + + device = accelerator.device + batch_size_v2s = 0 + input_ids_v2s = torch.empty((0, 1), dtype=torch.long, device=device) + labels_v2s = torch.empty((0, 1), dtype=torch.long, device=device) + p_mask_v2s = torch.empty((0, 1), dtype=torch.float32, device=device) + answer_lengths_v2s = torch.empty((0, 1), dtype=torch.long, device=device) + # Text-to-image samples + batch_size_t2i = 0 + mask_prob = torch.tensor(0.0, device=device) + t2i_masks = torch.empty((0, 1), dtype=torch.long, device=device) + input_ids_t2i = torch.empty((0, 1), dtype=torch.long, device=device) + labels_t2i = torch.empty((0, 1), dtype=torch.long, device=device) + batch_size_i2i = 0 + mask_prob_i2i = torch.tensor(0.0, device=device) + input_ids_i2i = torch.empty((0, 1), dtype=torch.long, device=device) + labels_i2i = torch.empty((0, 1), dtype=torch.long, device=device) + attention_masks_i2i = torch.empty((0, 1), dtype=torch.long, device=device) + + if selected_x2i_branch == "t2i" and t2i_batch is not None: + t2i_texts = t2i_batch["t2i"].get("texts", []) + t2i_images_tensor = t2i_batch["t2i"].get("images") + if isinstance(t2i_images_tensor, torch.Tensor) and t2i_images_tensor.shape[0] > 0: + t2i_images_tensor = t2i_images_tensor.to(device, non_blocking=True) + batch_size_t2i = t2i_images_tensor.shape[0] + ( + input_ids_t2i, + labels_t2i, + mask_prob, + _, + t2i_masks, + ) = prepare_inputs_and_labels(t2i_images_tensor, t2i_texts, config.training.min_masking_rate) + input_ids_t2i = input_ids_t2i.to(device, non_blocking=True) + labels_t2i = labels_t2i.to(device, non_blocking=True) + t2i_masks = t2i_masks.to(device, non_blocking=True) + if mask_prob.device != device: + mask_prob = mask_prob.to(device) + if batch_size_t2i > 0: + task_payloads.append( + _make_payload( + "t2i", + input_ids_t2i, + labels_t2i, + batch_size_t2i, + attention_mask=t2i_masks, + metadata={"mask_prob": mask_prob.clone()}, + weight=float(getattr(config.training, "t2i_coeff", 1.0)), + ) + ) + + if selected_x2i_branch == "i2i" and i2i_batch is not None: + i2i_prompts = i2i_batch["i2i"].get("prompts", []) + i2i_source_tensor = i2i_batch["i2i"].get("source_images") + i2i_target_tensor = i2i_batch["i2i"].get("target_images") + if ( + isinstance(i2i_source_tensor, torch.Tensor) + and isinstance(i2i_target_tensor, torch.Tensor) + and i2i_source_tensor.shape[0] > 0 + and i2i_target_tensor.shape[0] > 0 + ): + i2i_source_tensor = i2i_source_tensor.to(device, non_blocking=True) + i2i_target_tensor = i2i_target_tensor.to(device, non_blocking=True) + batch_size_i2i = i2i_source_tensor.shape[0] + ( + input_ids_i2i, + labels_i2i, + mask_prob_i2i, + attention_masks_i2i, + ) = prepare_inputs_and_labels_for_i2i( + i2i_source_tensor, + i2i_target_tensor, + i2i_prompts, + is_train=True, + ) + input_ids_i2i = input_ids_i2i.to(device, non_blocking=True) + labels_i2i = labels_i2i.to(device, non_blocking=True) + attention_masks_i2i = attention_masks_i2i.to(device, non_blocking=True) + if mask_prob_i2i.device != device: + mask_prob_i2i = mask_prob_i2i.to(device) + if batch_size_i2i > 0: + task_payloads.append( + _make_payload( + "i2i", + input_ids_i2i, + labels_i2i, + batch_size_i2i, + attention_mask=attention_masks_i2i, + metadata={"mask_prob": mask_prob_i2i.clone()}, + weight=float(getattr(config.training, "i2i_coeff", config.training.t2i_coeff)), + ) + ) + + # Language modeling samples + batch_size_lm = 0 + input_ids_lm = torch.empty((0, 1), dtype=torch.long, device=device) + labels_lm = torch.empty((0, 1), dtype=torch.long, device=device) + p_mask_lm = torch.empty((0, 1), dtype=torch.float32, device=device) + if train_dataloader_lm is not None: + lm_batch = batch.get("lm_flow") + if lm_batch is not None: + texts_lm = lm_batch["input_ids"] + batch_size_lm = len(texts_lm) + max_seq_for_lm = input_ids_t2i.shape[1] if batch_size_t2i > 0 else preproc_config.max_seq_length + input_ids_lm, labels_lm, p_mask_lm = prepare_inputs_and_labels_for_text(texts_lm, max_seq_for_lm) + input_ids_lm = input_ids_lm.to(device, non_blocking=True) + labels_lm = labels_lm.to(device, non_blocking=True) + p_mask_lm = p_mask_lm.to(device, non_blocking=True) + if batch_size_lm > 0: + task_payloads.append( + _make_payload( + "lm", + input_ids_lm, + labels_lm, + batch_size_lm, + p_mask=p_mask_lm, + weight=float(getattr(config.training, "lm_coeff", 1.0)), + ) + ) + + if isinstance(v2t_batch, dict): + video_tensor_text_raw = v2t_batch.get("video") + texts_vid = v2t_batch.get("captions", []) + else: + video_tensor_text_raw = None + texts_vid = [] + + if isinstance(v2s_batch, dict): + video_tensor_speech_raw = v2s_batch.get("video") + speech_items = v2s_batch.get("speech", []) + else: + video_tensor_speech_raw = None + speech_items = [] + + batch_size_v2t = video_tensor_text_raw.shape[0] if isinstance(video_tensor_text_raw, torch.Tensor) else 0 + batch_size_v2s = len(speech_items) + + video_tensor_text = ( + video_tensor_text_raw.to(device, non_blocking=True) + if isinstance(video_tensor_text_raw, torch.Tensor) + else torch.empty((0, 1, 1, 1, 1), device=device) + ) + video_tensor_speech = ( + video_tensor_speech_raw.to(device, non_blocking=True) + if isinstance(video_tensor_speech_raw, torch.Tensor) + else torch.empty((0, 1, 1, 1, 1), device=device) + ) + + batch_size_t2s_text = len(batch["t2s_flow"]["audio_path"]) + batch_size_s2t = len(batch["s2t_flow"]["audio_path"]) + + s2s_batch = batch.get("s2s_flow") + batch_size_s2s = 0 + if s2s_batch is not None: + batch_size_s2s = len(s2s_batch.get("emova_sft", [])) + len(s2s_batch.get("instructs2s", [])) + + mmu_batch = batch.get("mmu_flow") + batch_size_mmu = 0 + image_tensor_list = [] + texts_image = [] + if mmu_batch is not None: + image_tensor_list = mmu_batch.get("images", []) + texts_image = mmu_batch.get("text", []) + batch_size_mmu = len(image_tensor_list) + + s2t_flow = batch.get("s2t_flow", {}) + t2s_flow = batch.get("t2s_flow", {}) + audio_paths_s2t, texts_s2t = s2t_flow.get("audio_path", []), s2t_flow.get("text", []) + audio_paths_t2s, texts_t2s = t2s_flow.get("audio_path", []), t2s_flow.get("text", []) + audio_tokens_s2t = s2t_flow.get("audio_tokens", []) + audio_tokens_t2s = t2s_flow.get("audio_tokens", []) + + if batch_size_s2t > 0 and batch_size_t2s_text > 0: + if accelerator.num_processes > 1: + local_choice = 0 if accelerator.is_main_process and random.random() < 0.5 else 1 + drop_t2s = _broadcast_choice(local_choice, accelerator) == 0 + else: + drop_t2s = random.random() < 0.5 + + if drop_t2s: + audio_paths_t2s = [] + texts_t2s = [] + batch_size_t2s_text = 0 + else: + audio_paths_s2t = [] + texts_s2t = [] + batch_size_s2t = 0 + else: + batch_size_s2t = len(audio_paths_s2t) + batch_size_t2s_text = len(audio_paths_t2s) + + active_x2i_branch = selected_x2i_branch or "none" + logger.info( + f"x2i_branch: {active_x2i_branch}, batch_size_t2i: {batch_size_t2i}, batch_size_i2i: {batch_size_i2i}, batch_size_lm: {batch_size_lm}, " + f"batch_size_v2t: {batch_size_v2t}, batch_size_v2s: {batch_size_v2s}, batch_size_t2s: {batch_size_t2s_text}, " + f"batch_size_s2t: {batch_size_s2t}, batch_size_s2s: {batch_size_s2s}, batch_size_mmu: {batch_size_mmu}" + ) + offset = speech_vocab_start + + data_time_m.update(time.time() - end) + + failure_messages = [] + step_failed = False + + input_ids_vid = torch.empty((0, 1), dtype=torch.long, device=device) + labels_vid = torch.empty((0, 1), dtype=torch.long, device=device) + p_mask_vid = torch.empty((0, 1), dtype=torch.float32, device=device) + answer_lengths_vid = torch.empty((0, 1), dtype=torch.long, device=device) + + input_ids_v2s = torch.empty((0, 1), dtype=torch.long, device=device) + labels_v2s = torch.empty((0, 1), dtype=torch.long, device=device) + p_mask_v2s = torch.empty((0, 1), dtype=torch.float32, device=device) + answer_lengths_v2s = torch.empty((0, 1), dtype=torch.long, device=device) + + if batch_size_v2t > 0: + video_token_list = [] + for vid_idx, video in enumerate(video_tensor_text): + tokens, err = safe_video_get_code(video, vid_idx) + if err is not None: + failure_messages.append(err) + step_failed = True + break + tokens = tokens + len(uni_prompting.text_tokenizer) + video_token_list.append(tokens.view(-1)) + + if not step_failed and video_token_list: + video_tokens_text = torch.stack(video_token_list, dim=0) + + texts_with_prompt: List[str] + is_vid_inst = False + if texts_vid and isinstance(texts_vid[0], (list, tuple)) and isinstance(texts_vid[0][0], dict): + is_vid_inst = True + vid_inst_prompt: List[str] = [] + vid_inst_answer: List[str] = [] + for conv in texts_vid: + human_msg = "" + assistant_msg = "" + for turn in conv: + role = turn.get("from") + value = turn.get("value", "") + if role == "human": + human_msg = value.replace("\n", "") + elif role == "gpt": + assistant_msg = value + vid_inst_prompt.append(human_msg) + vid_inst_answer.append(assistant_msg) + texts_with_prompt = [ + "<|start_header_id|>user<|end_header_id|>\n" + f"{vid_inst_prompt[i]}<|eot_id|>" + "<|start_header_id|>assistant<|end_header_id|>\n" + f"{vid_inst_answer[i]}" + for i in range(len(vid_inst_answer)) + ] + else: + prompt_v2t_selected = random.choice(V2T_INSTRUCTION) + texts_with_prompt = [ + "<|start_header_id|>user<|end_header_id|>\n" + f"{prompt_v2t_selected}<|eot_id|>" + "<|start_header_id|>assistant<|end_header_id|>\n" + f"{text if isinstance(text, str) else str(text)}" + for text in texts_vid + ] + + input_ids_vid_tmp, prompt_masks_vid, labels_vid_tmp = uni_prompting((video_tokens_text, texts_with_prompt), 'v2t') + input_ids_vid_tmp, labels_vid_tmp, p_mask_vid, answer_lengths_vid = prepare_inputs_and_labels_for_mmu( + input_ids_vid_tmp, prompt_masks_vid, labels_vid_tmp + ) + input_ids_vid = input_ids_vid_tmp.to(device, non_blocking=True) + labels_vid = labels_vid_tmp.to(device, non_blocking=True) + p_mask_vid = p_mask_vid.to(device, non_blocking=True) + answer_lengths_vid = answer_lengths_vid.to(device, non_blocking=True) + else: + batch_size_v2t = 0 + if batch_size_v2t > 0: + task_payloads.append( + _make_payload( + "v2t", + input_ids_vid, + labels_vid, + batch_size_v2t, + p_mask=p_mask_vid, + answer_lengths=answer_lengths_vid, + weight=float(getattr(config.training, "v2t_coeff", 1.0)), + ) + ) + + if batch_size_v2s > 0 and not step_failed: + all_audio_tokens: list[torch.Tensor] = [] + for speech_entry in speech_items: + if isinstance(speech_entry, torch.Tensor): + tokens = speech_entry.to(device, non_blocking=True) + else: + tokens, err = safe_audio_encode(speech_entry, "v2s") + if err is not None: + failure_messages.append(err) + step_failed = True + break + tokens = tokens.to(device, non_blocking=True) + tokens_with_offset = tokens + offset + all_audio_tokens.append(tokens_with_offset) + + video_token_list_v2s: list[torch.Tensor] = [] + if not step_failed: + for vid_idx, video in enumerate(video_tensor_speech): + tokens, err = safe_video_get_code(video, vid_idx) + if err is not None: + failure_messages.append(err) + step_failed = True + break + tokens = tokens + len(uni_prompting.text_tokenizer) + video_token_list_v2s.append(tokens.view(-1)) + + if not step_failed and all_audio_tokens and video_token_list_v2s: + video_tokens_v2s = torch.stack(video_token_list_v2s, dim=0) + prompts_v2s = [ + f"<|start_header_id|>user<|end_header_id|>\n{random.choice(V2S_INSTRUCTION)}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n" + for _ in range(batch_size_v2s) + ] + input_ids_v2s_tmp, prompt_masks_v2s, labels_v2s_tmp = uni_prompting( + (video_tokens_v2s, prompts_v2s, all_audio_tokens), 'v2s_ip' + ) + input_ids_v2s_tmp, labels_v2s_tmp, p_mask_v2s_tmp, answer_lengths_v2s_tmp = prepare_inputs_and_labels_for_v2s( + input_ids_v2s_tmp, prompt_masks_v2s, labels_v2s_tmp + ) + input_ids_v2s = input_ids_v2s_tmp.to(device, non_blocking=True) + labels_v2s = labels_v2s_tmp.to(device, non_blocking=True) + p_mask_v2s = p_mask_v2s_tmp.to(device, non_blocking=True) + answer_lengths_v2s = answer_lengths_v2s_tmp.to(device, non_blocking=True) + else: + batch_size_v2s = 0 + if batch_size_v2s > 0: + task_payloads.append( + _make_payload( + "v2s", + input_ids_v2s, + labels_v2s, + batch_size_v2s, + p_mask=p_mask_v2s, + answer_lengths=answer_lengths_v2s, + weight=float(getattr(config.training, "v2s_coeff", config.training.t2s_coeff)), + ) + ) + + # *-------*-------*-------*-------*-------*-------*-------*-------*-------*-------*-------* + # Build formatted sequences for speech understanding + # *-------*-------*-------*-------*-------*-------*-------*-------*-------*-------*-------* + if not step_failed and batch_size_s2t > 0: + prompt_s2t = ['<|start_header_id|>user<|end_header_id|>\n' + prompt + '<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n' for prompt in S2T_INSTRUCTION] + + all_audio_tokens = [] + if not audio_tokens_s2t: + audio_tokens_s2t = [None] * len(audio_paths_s2t) + elif len(audio_tokens_s2t) < len(audio_paths_s2t): + audio_tokens_s2t = list(audio_tokens_s2t) + [None] * (len(audio_paths_s2t) - len(audio_tokens_s2t)) + + for path, cached_tokens in zip(audio_paths_s2t, audio_tokens_s2t): + source = cached_tokens if cached_tokens is not None else path + tokens, err = safe_audio_encode(source, "s2t") + if err is not None: + failure_messages.append(err) + step_failed = True + break + tokens = tokens.to(accelerator.device, non_blocking=True) + tokens_with_offset = tokens + offset + all_audio_tokens.append(tokens_with_offset) + + if not step_failed: + prompt = random.choice(prompt_s2t) + texts_with_prompt = [f"{prompt}{text}" for text in texts_s2t] + + input_ids_s2t, prompt_masks_s2t, labels_s2t = uni_prompting((all_audio_tokens, texts_with_prompt), 's2t') + # Preserve trailing EOS tokens in s2t targets for explicit prediction. + input_ids_s2t, labels_s2t, p_mask_s2t, answer_lengths_s2t = prepare_inputs_and_labels_for_s2t(input_ids_s2t, prompt_masks_s2t, labels_s2t) + else: + input_ids_s2t = torch.empty((0, 1), dtype=torch.long, device=accelerator.device) + labels_s2t = torch.empty((0, 1), dtype=torch.long, device=accelerator.device) + p_mask_s2t = torch.empty((0, 1), dtype=torch.float32, device=accelerator.device) + answer_lengths_s2t = torch.empty((0, 1), dtype=torch.long, device=accelerator.device) + if batch_size_s2t > 0: + task_payloads.append( + _make_payload( + "s2t", + input_ids_s2t, + labels_s2t, + batch_size_s2t, + p_mask=p_mask_s2t, + answer_lengths=answer_lengths_s2t, + weight=float(getattr(config.training, "s2t_coeff", 1.0)), + ) + ) + + # *-------*-------*-------*-------*-------*-------*-------*-------*-------*-------*-------* + # Build formatted sequences for speech generation + # *-------*-------*-------*-------*-------*-------*-------*-------*-------*-------*-------* + if not step_failed and batch_size_t2s_text > 0: + prompt_t2s = [prompt for prompt in T2S_INSTRUCTION] + + all_audio_tokens = [] + if not audio_tokens_t2s: + audio_tokens_t2s = [None] * len(audio_paths_t2s) + elif len(audio_tokens_t2s) < len(audio_paths_t2s): + audio_tokens_t2s = list(audio_tokens_t2s) + [None] * (len(audio_paths_t2s) - len(audio_tokens_t2s)) + + for path, cached_tokens in zip(audio_paths_t2s, audio_tokens_t2s): + source = cached_tokens if cached_tokens is not None else path + tokens, err = safe_audio_encode(source, "t2s") + if err is not None: + failure_messages.append(err) + step_failed = True + break + tokens = tokens.to(accelerator.device, non_blocking=True) + tokens_with_offset = tokens + offset + all_audio_tokens.append(tokens_with_offset) + + if not step_failed: + # Chat-style instruction formatting for T2S training + prompt = random.choice(prompt_t2s) + texts_with_prompt = [ + f"<|start_header_id|>user<|end_header_id|>\n{prompt}\n{text}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n" + for text in texts_t2s + ] + + # input_ids_t2s, prompt_masks_t2s, labels_t2s = uni_prompting((texts_with_prompt, all_audio_tokens), 't2s_ip') + input_ids_t2s, prompt_masks_t2s, labels_t2s = uni_prompting((texts_with_prompt, all_audio_tokens), 't2s_ip') + input_ids_t2s, labels_t2s, p_mask_t2s, answer_lengths_t2s = prepare_inputs_and_labels_for_t2s(input_ids_t2s, prompt_masks_t2s, labels_t2s) + else: + input_ids_t2s = torch.empty((0, 1), dtype=torch.long, device=accelerator.device) + labels_t2s = torch.empty((0, 1), dtype=torch.long, device=accelerator.device) + p_mask_t2s = torch.empty((0, 1), dtype=torch.float32, device=accelerator.device) + answer_lengths_t2s = torch.empty((0, 1), dtype=torch.long, device=accelerator.device) + if batch_size_t2s_text > 0: + task_payloads.append( + _make_payload( + "t2s", + input_ids_t2s, + labels_t2s, + batch_size_t2s_text, + p_mask=p_mask_t2s, + answer_lengths=answer_lengths_t2s, + weight=float(getattr(config.training, "t2s_coeff", 1.0)), + ) + ) + + audio_user_ids_s2s: list[torch.Tensor] = [] + audio_asst_ids_s2s: list[torch.Tensor] = [] + image_token_blocks_s2s: list[Optional[torch.Tensor]] = [] + input_ids_s2s = None + labels_s2s = None + p_mask_s2s = None + answer_lengths_s2s = None + + if not step_failed and batch_size_s2s > 0 and s2s_batch is not None: + s2s_sample_counter = 0 + + emova_samples = s2s_batch.get("emova_sft", []) + for sample_idx, (usr_ids, asst_ids, image_like) in enumerate(emova_samples): + usr_tensor = torch.tensor(usr_ids, dtype=torch.long, device=accelerator.device).unsqueeze(0) + asst_tensor = torch.tensor(asst_ids, dtype=torch.long, device=accelerator.device).unsqueeze(0) + + audio_user_ids_s2s.append(usr_tensor + offset) + audio_asst_ids_s2s.append(asst_tensor + offset) + + decoded_payload = maybe_decode_image(image_like) + if isinstance(decoded_payload, (list, tuple)): + token_payloads = [] + for local_img in decoded_payload: + if local_img is None: + continue + pixel_values = image_transform(local_img, resolution=preproc_config.resolution).to(accelerator.device) + if pixel_values.dim() == 3: + pixel_values = pixel_values.unsqueeze(0) + image_tokens_raw, err = safe_image_get_code(pixel_values, s2s_sample_counter) + if err is not None: + failure_messages.append(err) + step_failed = True + break + image_tokens = image_tokens_raw + len(uni_prompting.text_tokenizer) + token_payloads.append(image_tokens.view(1, -1).to(accelerator.device, non_blocking=True)) + s2s_sample_counter += 1 + if step_failed: + break + image_token_blocks_s2s.append(token_payloads if token_payloads else None) + else: + if decoded_payload is not None: + pixel_values = image_transform(decoded_payload, resolution=preproc_config.resolution).to(accelerator.device) + if pixel_values.dim() == 3: + pixel_values = pixel_values.unsqueeze(0) + image_tokens_raw, err = safe_image_get_code(pixel_values, s2s_sample_counter) + if err is not None: + failure_messages.append(err) + step_failed = True + break + image_tokens = image_tokens_raw + len(uni_prompting.text_tokenizer) + image_token_blocks_s2s.append(image_tokens.view(1, -1).to(accelerator.device, non_blocking=True)) + s2s_sample_counter += 1 + else: + image_token_blocks_s2s.append(None) + + instruct_samples = [] if step_failed else s2s_batch.get("instructs2s", []) + if not step_failed: + for sample_idx, (usr_wav, asst_wav, _) in enumerate(instruct_samples): + user_tokens, err_usr = safe_audio_encode(usr_wav, "s2s-user") + if err_usr is not None: + failure_messages.append(err_usr) + step_failed = True + break + asst_tokens, err_asst = safe_audio_encode(asst_wav, "s2s-assistant") + if err_asst is not None: + failure_messages.append(err_asst) + step_failed = True + break + + user_tokens = user_tokens.to(accelerator.device, non_blocking=True) + asst_tokens = asst_tokens.to(accelerator.device, non_blocking=True) + if user_tokens.size(-1) > config.dataset.preprocessing.max_aud_length: + duration_tokens = config.dataset.preprocessing.max_aud_length + user_tokens = user_tokens[..., :duration_tokens] + if asst_tokens.size(-1) > config.dataset.preprocessing.max_aud_length: + duration_tokens = config.dataset.preprocessing.max_aud_length + asst_tokens = asst_tokens[..., :duration_tokens] + if user_tokens.dim() == 1: + user_tokens = user_tokens.unsqueeze(0) + if asst_tokens.dim() == 1: + asst_tokens = asst_tokens.unsqueeze(0) + audio_user_ids_s2s.append(user_tokens + offset) + audio_asst_ids_s2s.append(asst_tokens + offset) + image_token_blocks_s2s.append(None) + s2s_sample_counter += 1 + + if not step_failed and audio_user_ids_s2s: + input_ids_s2s, prompt_masks_s2s, labels_s2s = uni_prompting( + (audio_user_ids_s2s, audio_asst_ids_s2s, image_token_blocks_s2s), + 's2s_ip' + ) + + input_ids_s2s, labels_s2s, p_mask_s2s, answer_lengths_s2s = prepare_inputs_and_labels_for_s2s( + input_ids_s2s, + prompt_masks_s2s, + labels_s2s, + ) + + if ( + answer_lengths_s2s is not None + and answer_lengths_s2s.numel() > 0 + and accelerator.is_main_process + ): + per_sample_lengths = answer_lengths_s2s[:, 0].detach().cpu() + lengths_list = [int(length) for length in per_sample_lengths.tolist()] + stats_msg = ( + f"min={int(per_sample_lengths.min().item())}, " + f"max={int(per_sample_lengths.max().item())}, " + f"mean={per_sample_lengths.float().mean().item():.2f}" + ) + logger.info("S2S answer lengths (no pad): %s | %s", lengths_list, stats_msg) + + if input_ids_s2s is None: + device = accelerator.device + input_ids_s2s = torch.empty((0, 1), dtype=torch.long, device=device) + labels_s2s = torch.empty((0, 1), dtype=torch.long, device=device) + p_mask_s2s = torch.empty((0, 1), dtype=torch.float32, device=device) + answer_lengths_s2s = torch.empty((0, 1), dtype=torch.long, device=device) + if input_ids_s2s.shape[0] > 0: + task_payloads.append( + _make_payload( + "s2s", + input_ids_s2s, + labels_s2s, + input_ids_s2s.shape[0], + p_mask=p_mask_s2s, + answer_lengths=answer_lengths_s2s, + weight=float(getattr(config.training, "s2s_coeff", config.training.t2s_coeff)), + ) + ) + + input_ids_mmu = None + labels_mmu = None + p_mask_mmu = None + answer_lengths_mmu = None + + if not step_failed and batch_size_mmu > 0: + batch_image_ids_list = [] + batch_text_ids = [] + + for b_idx, image_list in enumerate(image_tensor_list): + per_img_ids = [] + for j, img in enumerate(image_list): + tok, err = safe_image_get_code( + img.to(accelerator.device, non_blocking=True), + sample_index=j + ) + if err is not None: + failure_messages.append(err) + step_failed = True + break + + tok = tok.to(accelerator.device, non_blocking=True).view(-1).long() + tok = tok + len(uni_prompting.text_tokenizer) + per_img_ids.append(tok) + + if step_failed: + break + + batch_image_ids_list.append(per_img_ids) + text_ids = uni_prompting.text_tokenizer.encode(texts_image[b_idx], add_special_tokens=False) + batch_text_ids.append(text_ids) + + if not step_failed: + input_ids_mmu, prompt_masks_mmu, labels_mmu = uni_prompting.mmu_mult_prompt( + batch_image_ids_list=batch_image_ids_list, + batch_text_ids=batch_text_ids, + ) + + ( + input_ids_mmu, + labels_mmu, + p_mask_mmu, + answer_lengths_mmu + ) = prepare_inputs_and_labels_for_mmu(input_ids_mmu, prompt_masks_mmu, labels_mmu) + + input_ids_mmu = input_ids_mmu.to(accelerator.device, non_blocking=True) + labels_mmu = labels_mmu.to(accelerator.device, non_blocking=True) + p_mask_mmu = p_mask_mmu.to(accelerator.device, non_blocking=True) + answer_lengths_mmu = answer_lengths_mmu.to(accelerator.device, non_blocking=True) + + if batch_size_mmu == 0 or input_ids_mmu is None: + input_ids_mmu = torch.empty((0, 1), dtype=torch.long, device=accelerator.device) + labels_mmu = torch.empty((0, 1), dtype=torch.long, device=accelerator.device) + p_mask_mmu = torch.empty((0, 1), dtype=torch.float32, device=accelerator.device) + answer_lengths_mmu = torch.empty((0, 1), dtype=torch.long, device=accelerator.device) + else: + task_payloads.append( + _make_payload( + "mmu", + input_ids_mmu, + labels_mmu, + batch_size_mmu, + p_mask=p_mask_mmu, + answer_lengths=answer_lengths_mmu, + weight=float(getattr(config.training, "mmu_coeff", 0.0)), + ) + ) + if not step_failed: + total_batch_size_t2s = batch_size_t2s_text + else: + total_batch_size_t2s = batch_size_t2s_text + + failure_tensor = torch.tensor(1 if step_failed else 0, device=accelerator.device, dtype=torch.int32) + failure_sum = accelerator.reduce(failure_tensor, reduction='sum') + if failure_sum.item() > 0: + if accelerator.is_main_process and failure_messages: + for msg in failure_messages: + logger.warning(f"Skipping global step {global_step} due to preprocessing failure: {msg}") + batch_time_m.reset() + data_time_m.reset() + end = time.time() + continue + + # -------------------------------------------------------------------------------- + # for name, tensor in [ + # ("t2i", input_ids_t2i), + # ("i2i", input_ids_i2i), + # ("lm", input_ids_lm), + # ("mmu", input_ids_mmu), + # ("vid", input_ids_vid), + # ("s2t", input_ids_s2t), + # ("s2s", input_ids_s2s), + # ("t2s", input_ids_t2s), + # ]: + # if tensor is not None: + # print(f"{name:>4}: shape={getattr(tensor, 'shape', None)}, len={len(tensor) if hasattr(tensor, '__len__') else 'N/A'}") + + # 1. Define padding values + pad_token_id = uni_prompting.text_tokenizer.eos_token_id + + # 2. Find the maximum sequence length in the current batch + seq_lengths = [payload.input_ids.shape[1] for payload in task_payloads if payload.input_ids.shape[0] > 0] + # Ensure we still account for placeholder tensors that were required for downstream logic + if input_ids_t2i.shape[0] == 0: + seq_lengths.append(input_ids_t2i.shape[1]) + if input_ids_i2i.shape[0] == 0: + seq_lengths.append(input_ids_i2i.shape[1]) + if input_ids_lm.shape[0] == 0: + seq_lengths.append(input_ids_lm.shape[1]) + if input_ids_vid.shape[0] == 0: + seq_lengths.append(input_ids_vid.shape[1]) + if input_ids_v2s.shape[0] == 0: + seq_lengths.append(input_ids_v2s.shape[1]) + if input_ids_s2t.shape[0] == 0: + seq_lengths.append(input_ids_s2t.shape[1]) + if input_ids_t2s.shape[0] == 0: + seq_lengths.append(input_ids_t2s.shape[1]) + if input_ids_s2s.shape[0] == 0: + seq_lengths.append(input_ids_s2s.shape[1]) + if input_ids_mmu.shape[0] == 0: + seq_lengths.append(input_ids_mmu.shape[1]) + seq_lengths = [length for length in seq_lengths if length > 0] + if not seq_lengths: + seq_lengths.append(1) + max_len = max(seq_lengths) + + # 3. Pad all tensors to the max_len + input_ids_t2i = pad_tensor(input_ids_t2i, max_len, pad_token_id) + labels_t2i = pad_tensor(labels_t2i, max_len, -100) + if t2i_masks.shape[0] > 0: + t2i_masks = pad_tensor(t2i_masks.long(), max_len, 0) + else: + t2i_masks = torch.empty((0, max_len), dtype=torch.long, device=device) + + input_ids_i2i = pad_tensor(input_ids_i2i, max_len, pad_token_id) + labels_i2i = pad_tensor(labels_i2i, max_len, -100) + if attention_masks_i2i.shape[0] > 0: + attention_masks_i2i = pad_tensor(attention_masks_i2i.long(), max_len, 0) + else: + attention_masks_i2i = torch.empty((0, max_len), dtype=torch.long, device=device) + + + input_ids_lm = pad_tensor(input_ids_lm, max_len, pad_token_id) + labels_lm = pad_tensor(labels_lm, max_len, -100) + p_mask_lm = pad_tensor(p_mask_lm, max_len, 1.0) + + input_ids_vid = pad_tensor(input_ids_vid, max_len, pad_token_id) + input_ids_v2s = pad_tensor(input_ids_v2s, max_len, pad_token_id) + input_ids_s2t = pad_tensor(input_ids_s2t, max_len, pad_token_id) + input_ids_t2s = pad_tensor(input_ids_t2s, max_len, pad_token_id) + input_ids_s2s = pad_tensor(input_ids_s2s, max_len, pad_token_id) + input_ids_mmu = pad_tensor(input_ids_mmu, max_len, pad_token_id) + labels_vid = pad_tensor(labels_vid, max_len, -100) + labels_v2s = pad_tensor(labels_v2s, max_len, -100) + labels_s2t = pad_tensor(labels_s2t, max_len, -100) + labels_t2s = pad_tensor(labels_t2s, max_len, -100) + labels_s2s = pad_tensor(labels_s2s, max_len, -100) + labels_mmu = pad_tensor(labels_mmu, max_len, -100) + p_mask_vid = pad_tensor(p_mask_vid, max_len, 1.0) + p_mask_v2s = pad_tensor(p_mask_v2s, max_len, 1.0) + p_mask_s2t = pad_tensor(p_mask_s2t, max_len, 1.0) + p_mask_t2s = pad_tensor(p_mask_t2s, max_len, 1.0) + p_mask_s2s = pad_tensor(p_mask_s2s, max_len, 1.0) + p_mask_mmu = pad_tensor(p_mask_mmu, max_len, 1.0) + answer_lengths_vid = pad_answer_lengths(answer_lengths_vid, max_len) + answer_lengths_v2s = pad_answer_lengths(answer_lengths_v2s, max_len) + answer_lengths_s2t = pad_answer_lengths(answer_lengths_s2t, max_len) + answer_lengths_t2s = pad_answer_lengths(answer_lengths_t2s, max_len) + answer_lengths_s2s = pad_answer_lengths(answer_lengths_s2s, max_len) + answer_lengths_mmu = pad_answer_lengths(answer_lengths_mmu, max_len) + + input_ids = torch.cat(( + input_ids_t2i, + input_ids_i2i, + input_ids_lm, + input_ids_mmu, + input_ids_vid, + input_ids_v2s, + input_ids_s2t, + input_ids_s2s, + input_ids_t2s + ), dim=0) + labels = torch.cat(( + labels_t2i, + labels_i2i, + labels_lm, + labels_mmu, + labels_vid, + labels_v2s, + labels_s2t, + labels_s2s, + labels_t2s + ), dim=0) + + # w/o texts and images + if batch_size_lm == 0: + p_mask_lm = torch.empty((0, max_len), dtype=torch.float32, device=device) + if batch_size_t2i == 0 and t2i_masks.shape[0] == 0: + t2i_masks = torch.empty((0, max_len), dtype=torch.long, device=device) + + if global_step == 0 and epoch == 0: + logger.info("Input ids: {}".format(input_ids)) + logger.info("Labels: {}".format(labels)) + + logger.info("Input ids shape: {}".format(input_ids.shape)) + # with accelerator.accumulate(model): + logits, loss_t2i, loss_i2i, loss_lm, loss_mmu, loss_vid, loss_v2s, loss_s2t, loss_s2s, loss_t2s = accelerator.unwrap_model(model).forward_process( + # logits, loss_t2i, loss_lm, loss_mmu, loss_vid, loss_s2t, loss_t2s = model.forward_process( + input_ids=input_ids, + labels=labels, + batch_size_t2i=batch_size_t2i, + batch_size_i2i=batch_size_i2i, + batch_size_lm=batch_size_lm, + batch_size_mmu=batch_size_mmu, + batch_size_v2t=batch_size_v2t, + batch_size_v2s=batch_size_v2s, + batch_size_s2t=batch_size_s2t, + batch_size_s2s=batch_size_s2s, + batch_size_t2s=total_batch_size_t2s, + max_seq_length=config.dataset.preprocessing.max_seq_length, + attention_masks_i2i=attention_masks_i2i, + p_mask_lm=p_mask_lm, + p_mask_mmu=p_mask_mmu, + p_mask_vid=p_mask_vid, + p_mask_v2s=p_mask_v2s, + p_mask_s2t=p_mask_s2t, + p_mask_s2s=p_mask_s2s, + p_mask_t2s=p_mask_t2s, + answer_lengths_mmu=answer_lengths_mmu, + answer_lengths_vid=answer_lengths_vid, + answer_lengths_v2s=answer_lengths_v2s, + answer_lengths_s2t=answer_lengths_s2t, + answer_lengths_s2s=answer_lengths_s2s, + answer_lengths_t2s=answer_lengths_t2s, + t2i_masks=t2i_masks, + t2s_vocab_start=speech_vocab_start, + t2s_codebook_size=audio_codebook_size, + t2s_special_token_ids=t2s_special_token_ids, + text_vocab_size_override=len(uni_prompting.text_tokenizer), + ) + + if batch_size_t2i == 0: + loss_t2i = loss_t2i.new_zeros(()) + if batch_size_i2i == 0: + loss_i2i = loss_i2i.new_zeros(()) + + # Gather the losses across all processes for logging (use reduce to avoid shape mismatches) + avg_loss_t2i = accelerator.reduce(loss_t2i, reduction='mean') + avg_loss_i2i = accelerator.reduce(loss_i2i, reduction='mean') + avg_loss_lm = accelerator.reduce(loss_lm, reduction='mean') + avg_loss_mmu = accelerator.reduce(loss_mmu, reduction='mean') + avg_loss_vid = accelerator.reduce(loss_vid, reduction='mean') + avg_loss_v2s = accelerator.reduce(loss_v2s, reduction='mean') + avg_loss_s2t = accelerator.reduce(loss_s2t, reduction='mean') + avg_loss_s2s = accelerator.reduce(loss_s2s, reduction='mean') + if not torch.isfinite(loss_t2s): + if labels_t2s.numel() > 0: + speech_vocab_end = speech_vocab_start + audio_codebook_size + valid_mask = labels_t2s != -100 + if valid_mask.any(): + labels_valid = labels_t2s[valid_mask] + below_count = (labels_valid < speech_vocab_start).sum().item() + above_count = (labels_valid >= speech_vocab_end).sum().item() + labels_min = labels_valid.min().item() + labels_max = labels_valid.max().item() + else: + below_count = above_count = 0 + labels_min = labels_max = -100 + p_mask_min = p_mask_t2s.min().item() if p_mask_t2s.numel() > 0 else float("nan") + ans_len_min = ( + answer_lengths_t2s.min().item() + if answer_lengths_t2s.numel() > 0 + else float("nan") + ) + accelerator.print( + "[t2s NaN debug] " + f"rank={accelerator.process_index} step={global_step} " + f"slice=({speech_vocab_start}, {speech_vocab_end}) " + f"labels_min={labels_min} labels_max={labels_max} " + f"below_slice={below_count} above_slice={above_count} " + f"p_mask_min={p_mask_min} answer_len_min={ans_len_min}" + ) + accelerator.print( + f"[rank {accelerator.process_index}] t2s loss became NaN/Inf at global step {global_step} " + f"(local value: {loss_t2s.item()})" + ) + logger.warning( + "[rank %s] t2s loss became NaN/Inf at global step %s (local value: %s)", + accelerator.process_index, + global_step, + loss_t2s.item(), + ) + avg_loss_t2s = accelerator.reduce(loss_t2s, reduction='mean') + if not torch.isfinite(avg_loss_t2s): + accelerator.print( + f"[rank {accelerator.process_index}] reduced t2s loss NaN/Inf at global step {global_step} " + f"(value after all-reduce: {avg_loss_t2s.item()})" + ) + if accelerator.is_main_process: + logger.warning( + "Reduced t2s loss became NaN/Inf at global step %s (value after all-reduce: %s)", + global_step, + avg_loss_t2s.item(), + ) + + mmu_coeff = getattr(config.training, "mmu_coeff", 0.0) + i2i_coeff = getattr(config.training, "i2i_coeff", config.training.t2i_coeff) + s2s_coeff = getattr(config.training, "s2s_coeff", config.training.t2s_coeff) + v2s_coeff = getattr(config.training, "v2s_coeff", config.training.t2s_coeff) + loss = ( + config.training.t2i_coeff * loss_t2i + + i2i_coeff * loss_i2i + + config.training.lm_coeff * loss_lm + + mmu_coeff * loss_mmu + + config.training.v2t_coeff * loss_vid + + v2s_coeff * loss_v2s + + config.training.s2t_coeff * loss_s2t + + s2s_coeff * loss_s2s + + config.training.t2s_coeff * loss_t2s + ) + + if batch_size_t2i > 0: + local_masking_rate = mask_prob.float().mean() + else: + local_masking_rate = torch.tensor(0.0, device=accelerator.device) + avg_masking_rate = accelerator.reduce(local_masking_rate, reduction='mean') + + if batch_size_i2i > 0: + local_masking_rate_i2i = mask_prob_i2i.float().mean() + else: + local_masking_rate_i2i = torch.tensor(0.0, device=accelerator.device) + avg_masking_rate_i2i = accelerator.reduce(local_masking_rate_i2i, reduction='mean') + + if batch_size_s2s > 0 and p_mask_s2s.numel() > 0: + local_masking_rate_s2s = p_mask_s2s.float().mean() + else: + local_masking_rate_s2s = torch.tensor(0.0, device=accelerator.device) + avg_masking_rate_s2s = accelerator.reduce(local_masking_rate_s2s, reduction='mean') + + if batch_size_v2s > 0 and p_mask_v2s.numel() > 0: + local_masking_rate_v2s = p_mask_v2s.float().mean() + else: + local_masking_rate_v2s = torch.tensor(0.0, device=accelerator.device) + avg_masking_rate_v2s = accelerator.reduce(local_masking_rate_v2s, reduction='mean') + + accelerator.backward(loss) + + if config.training.max_grad_norm is not None and accelerator.sync_gradients: + accelerator.clip_grad_norm_(model.parameters(), config.training.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + + # log gradient norm before zeroing it + if ( + accelerator.sync_gradients + and (global_step + 1) % config.experiment.log_grad_norm_every == 0 + and accelerator.is_main_process + ): + log_grad_norm(model, accelerator, global_step + 1) + + optimizer.zero_grad(set_to_none=True) + + if accelerator.sync_gradients: + batch_time_m.update(time.time() - end) + end = time.time() + + # Log metrics + if (global_step + 1) % config.experiment.log_every == 0: + samples_per_second_per_gpu = ( + config.training.gradient_accumulation_steps * total_batch_size_per_gpu / batch_time_m.val + ) + logs = { + "lr": lr_scheduler.get_last_lr()[0], + "avg_masking_rate": avg_masking_rate.item(), + "avg_masking_rate_i2i": avg_masking_rate_i2i.item(), + "avg_masking_rate_v2s": avg_masking_rate_v2s.item(), + "avg_masking_rate_s2s": avg_masking_rate_s2s.item(), + "samples/sec/gpu": samples_per_second_per_gpu, + "data_time": data_time_m.val, + "batch_time": batch_time_m.val, + } + + loss_entries = [ + ("step_loss_t2i", avg_loss_t2i), + ("step_loss_i2i", avg_loss_i2i), + ("step_loss_lm", avg_loss_lm), + ("step_loss_mmu", avg_loss_mmu), + ("step_loss_vid", avg_loss_vid), + ("step_loss_v2s", avg_loss_v2s), + ("step_loss_s2t", avg_loss_s2t), + ("step_loss_s2s", avg_loss_s2s), + ("step_loss_t2s", avg_loss_t2s), + ] + + loss_log_parts = [] + for key, value in loss_entries: + loss_value = value.item() + if loss_value != 0.0: + logs[key] = loss_value + loss_log_parts.append(f"{key.replace('step_', '').capitalize()}: {loss_value:0.4f}") + + accelerator.log(logs, step=global_step + 1) + + loss_str = " ".join(loss_log_parts) + logger.info( + "Step: %d %s Data (t): %.4f, %.2f/s/gpu Batch (t): %.4f LR: %.6f" + % ( + global_step + 1, + loss_str, + data_time_m.val, + samples_per_second_per_gpu, + batch_time_m.val, + lr_scheduler.get_last_lr()[0], + ) + ) + + # resetting batch / data time meters per log window + batch_time_m.reset() + data_time_m.reset() + + # Save model checkpoint + if (global_step + 1) % config.experiment.save_every == 0: + save_checkpoint(model, config, accelerator, global_step + 1, uni_prompting) + + # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + # ++++++++++++++++++++++ RUN EVALUATION +++++++++++++++++++++++++ + # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + # if global_step == 0 or (global_step + 1) % config.experiment.get("eval_every", 5000) == 0: + if (global_step + 1) % config.experiment.get("eval_every", 5000) == 0: + run_evaluation( + model=accelerator.unwrap_model(model), + vq_model_image=vq_model_image, + vq_model_audio=vq_model_audio, + uni_prompting=uni_prompting, + config=config, + accelerator=accelerator, + global_step=global_step + 1 + ) + # Evaluation function sets model back to train mode internally + # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + + global_step += 1 + + if global_step >= config.training.max_train_steps: + break + + if global_step >= config.training.max_train_steps: + break + + accelerator.wait_for_everyone() + + save_checkpoint(model, config, accelerator, global_step, uni_prompting) + + if accelerator.is_main_process: + model = accelerator.unwrap_model(model) + model.save_pretrained(config.experiment.output_dir, safe_serialization=True) + + accelerator.end_training() + +@torch.no_grad() +def visualize_predictions(*args, **kwargs): + # This function is not called in the main loop but kept for compatibility + pass + +@torch.no_grad() +def generate_images(*args, **kwargs): + # This function is not called in the main loop but kept for compatibility + pass + +@torch.no_grad() +def understanding_images(*args, **kwargs): + # This function is not called in the main loop but kept for compatibility + pass + +def save_checkpoint(model, config, accelerator, global_step, uni_prompting): + output_dir = config.experiment.output_dir + checkpoints_total_limit = config.experiment.get("checkpoints_total_limit", None) + + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if accelerator.is_main_process and checkpoints_total_limit is not None: + checkpoints = os.listdir(output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= checkpoints_total_limit: + num_to_remove = len(checkpoints) - checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = Path(output_dir) / f"checkpoint-{global_step}" + + # retrieve the model on all processes for deepspeed stage 3 to work then save on one process (we are not using stage 3 yet) + # XXX: could also make this conditional on deepspeed + state_dict = accelerator.get_state_dict(model) + if accelerator.is_main_process: + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.save_pretrained( + save_path / "unwrapped_model", + save_function=accelerator.save, + state_dict=state_dict, + safe_serialization=True + ) + json.dump({"global_step": global_step}, (save_path / "metadata.json").open("w+")) + logger.info(f"Saved state to {save_path}") + + # save tokenizer + uni_prompting.text_tokenizer.save_pretrained(save_path/ "unwrapped_model") + + +def log_grad_norm(model, accelerator, global_step): + for name, param in model.named_parameters(): + if param.grad is not None: + grads = param.grad.detach().data + grad_norm = (grads.norm(p=2) / grads.numel()).item() + accelerator.log({"grad_norm/" + name: grad_norm}, step=global_step) + + +if __name__ == "__main__": + main() diff --git a/MMaDA/training/train_omada_inst_test.py b/MMaDA/training/train_omada_inst_test.py new file mode 100644 index 0000000000000000000000000000000000000000..ae12aa98fa999a32a563a1e5e72068636215c9e1 --- /dev/null +++ b/MMaDA/training/train_omada_inst_test.py @@ -0,0 +1,3147 @@ +# Copyright 2025 AIDAS Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import base64 +import binascii +import os +import sys +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +os.environ["TOKENIZERS_PARALLELISM"] = "true" +import json +import pandas +import logging +import math +import torch.nn.functional as F +import shutil +import time +import cv2 +import glob +import random +import contextlib +from tqdm import tqdm +from pathlib import Path +from typing import Optional, Union +import csv +import numpy as np +from PIL import Image +from io import BytesIO +from omegaconf import OmegaConf, DictConfig +import wandb +import torch +from torch.optim import AdamW +from lightning.pytorch.utilities import CombinedLoader + +from transformers import AutoTokenizer, AutoConfig +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import DistributedType, set_seed +# +++++ I2I-specific Imports +++++ +from datasets import load_dataset +from torch.utils.data import Dataset, DataLoader +from tqdm.auto import tqdm +# ++++++++++++++++++++++++++++++ + +# +++++ Omni-modal-specific Imports +++++ +from models.modeling_emova_speech_tokenizer import EMOVASpeechTokenizer +from datasets import load_dataset +from torch.utils.data import Dataset, DataLoader, DistributedSampler +from tqdm.auto import tqdm +from training.data import ( + SpeechTextDataset, + MixedSpeechTextDataset, + Speech2SpeechDataset, + TextImageInterleavedDataset, + load_video_mp4, + VideoCaptionDataset, + S2T_INSTRUCTION, + T2S_INSTRUCTION, + s2s_collate_fn, +) +# import librosa + +from training.data import Text2ImageDataset, HQEditX2IDataset, HFInstructionTextDataset +from training.utils import get_config, flatten_omega_conf, image_transform +from training.imagenet_dataset import ImageNetDataset + +from models import MAGVITv2, get_mask_schedule, OMadaModelLM, OMadaConfig +from training.prompting_utils import UniversalPrompting +from models.lr_schedulers import get_scheduler +from models.logging import set_verbosity_info, set_verbosity_error + +from torch.utils.data import DataLoader, Dataset +from torch.utils.data.distributed import DistributedSampler + +# ++++++++ EVALUATION IMPORTS ++++++++ +import re +import editdistance +import soundfile as sf +from functools import partial +from transformers import pipeline +# ++++++++++++++++++++++++++++++++++++ + +SYSTEM_PROMPT_LEN = 28 + +from training.utils import get_config, flatten_omega_conf, mask_or_random_replace_tokens, AverageMeter + +try: + import apex + + is_apex_available = True +except ImportError: + is_apex_available = False + +logger = get_logger(__name__, log_level="INFO") + +def pad_tensor(tensor, length, value): + pad_size = length - tensor.shape[1] + if pad_size <= 0: + return tensor + # Pad on the right side of the sequence (last dimension) + return torch.nn.functional.pad(tensor, (0, pad_size), "constant", value) + +def pad_answer_lengths(ans: torch.Tensor, length: int) -> torch.Tensor: + b, l = ans.shape + if l >= length: + return ans + pad_block = ans[:, :1].expand(b, length - l) + return torch.cat([ans, pad_block], dim=1) + +def resize_vocab(model, config): + logger.info(f"Resizing token embeddings to {config.model.omada.new_vocab_size}") + model.resize_token_embeddings(config.model.omada.new_vocab_size) + +def get_vq_model_class(model_type): + if model_type == "magvitv2": + return MAGVITv2 + elif model_type == "emova": + return EMOVASpeechTokenizer.from_pretrained( + "Emova-ollm/emova_speech_tokenizer_hf" + ) + else: + raise ValueError(f"model_type {model_type} not supported.") + +def collate_fn_audio(batch): + # In this setup, the tokenizer handles batching of audio paths + return { + 'audio_path': [item['audio_path'] for item in batch], + 'text': [item['text'] for item in batch], + } + + +def collate_fn_mmu_mult(batch): + return { + 'images': [item['images'] for item in batch], + 'text': [item['text'] for item in batch], + } + + +def collate_fn_x2i(batch): + t2i_texts: list[str] = [] + t2i_images: list[torch.Tensor] = [] + + i2i_prompts: list[str] = [] + i2i_source_images: list[torch.Tensor] = [] + i2i_target_images: list[torch.Tensor] = [] + + t2i_ref: Optional[torch.Tensor] = None + i2i_source_ref: Optional[torch.Tensor] = None + i2i_target_ref: Optional[torch.Tensor] = None + + for sample in batch: + input_prompt = sample.get("input_prompt") + output_prompt = sample.get("output_prompt") + edit_prompt = sample.get("edit_prompt") + inverse_prompt = sample.get("inverse_prompt") + input_image = sample.get("input_image") + output_image = sample.get("output_image") + + if isinstance(input_image, torch.Tensor) and t2i_ref is None: + t2i_ref = input_image + if isinstance(output_image, torch.Tensor) and t2i_ref is None: + t2i_ref = output_image + + t2i_candidates: list[tuple[str, torch.Tensor]] = [] + if input_prompt and isinstance(input_image, torch.Tensor): + t2i_candidates.append((input_prompt, input_image)) + if output_prompt and isinstance(output_image, torch.Tensor): + t2i_candidates.append((output_prompt, output_image)) + i2i_candidates: list[tuple[str, torch.Tensor, torch.Tensor]] = [] + if edit_prompt and isinstance(input_image, torch.Tensor) and isinstance(output_image, torch.Tensor): + i2i_candidates.append((edit_prompt, input_image, output_image)) + if inverse_prompt and isinstance(input_image, torch.Tensor) and isinstance(output_image, torch.Tensor): + i2i_candidates.append((inverse_prompt, output_image, input_image)) + + branch_choices: list[tuple[str, tuple]] = [] + if t2i_candidates: + branch_choices.append(("t2i", random.choice(t2i_candidates))) + if i2i_candidates: + branch_choices.append(("i2i", random.choice(i2i_candidates))) + + if branch_choices: + branch, payload = random.choice(branch_choices) + if branch == "t2i": + chosen_prompt, chosen_image = payload # type: ignore[misc] + t2i_texts.append(chosen_prompt) + t2i_images.append(chosen_image) + else: + prompt_choice, src_img, tgt_img = payload # type: ignore[misc] + if i2i_source_ref is None: + i2i_source_ref = src_img + if i2i_target_ref is None: + i2i_target_ref = tgt_img + i2i_prompts.append(prompt_choice) + i2i_source_images.append(src_img) + i2i_target_images.append(tgt_img) + + def stack_or_empty(tensors: list[torch.Tensor], ref: Optional[torch.Tensor]) -> torch.Tensor: + if tensors: + return torch.stack(tensors, dim=0) + if ref is not None: + c, h, w = ref.shape[-3:] + return torch.empty((0, c, h, w), dtype=ref.dtype) + return torch.empty((0, 3, 0, 0), dtype=torch.float32) + + return { + "t2i": { + "texts": t2i_texts, + "images": stack_or_empty(t2i_images, t2i_ref), + }, + "i2i": { + "prompts": i2i_prompts, + "source_images": stack_or_empty(i2i_source_images, i2i_source_ref if i2i_source_ref is not None else t2i_ref), + "target_images": stack_or_empty(i2i_target_images, i2i_target_ref if i2i_target_ref is not None else t2i_ref), + }, + } + +def collate_fn_video_caption(batch): + + batch = [item for item in batch if item is not None] + if len(batch) == 0: + return None + + frame_list = [] + input_ids_list = [] + for item in batch: + frame_tensor = torch.stack(item['video'], dim=0) # (T, C, H, W) + frame_list.append(frame_tensor) + input_ids_list.append(item['caption']) + + frames = torch.stack(frame_list, dim=0) # (B, T, C, H, W) + + return { + "video": frames, # torch tensor (B, T, C, H, W) + "captions": input_ids_list # input_ids (B, seq_len) + } + +def s2t_eval_collate_fn(batch, vq_model_audio, tokenizer, uni_prompting, config): + + audio_tokens_batch = [] + offset = len(uni_prompting.text_tokenizer) + int(config.model.omada.codebook_size) + for item in batch: + path = item['audio_path'] + tokens = vq_model_audio.encode(path) + tokens_with_offset = tokens + offset + audio_tokens_batch.append(tokens_with_offset) + + sptids_dict = uni_prompting.sptids_dict + device = audio_tokens_batch[0].device + batched_input_ids = [] + + for audio_tokens in audio_tokens_batch: + task_tensor = sptids_dict['<|s2t|>'].to(device).unsqueeze(0) + soa_tensor = sptids_dict['<|soa|>'].to(device).unsqueeze(0) + eoa_tensor = sptids_dict['<|eoa|>'].to(device).unsqueeze(0) + audio_block = torch.cat([task_tensor, soa_tensor, audio_tokens, eoa_tensor], dim=1) + + prompt_text = random.choice(S2T_INSTRUCTION) + full_prompt_text = f'<|start_header_id|>user<|end_header_id|>\n{prompt_text}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n' + prompt_tensor = tokenizer(full_prompt_text, return_tensors="pt").input_ids.to(device) + + final_sequence = torch.cat([audio_block, prompt_tensor], dim=1) + batched_input_ids.append(final_sequence.squeeze(0)) + + max_len = max(seq.size(0) for seq in batched_input_ids) + pad_token_id = 126093 + + final_batch_input_ids = torch.full( + (len(batched_input_ids), max_len), + pad_token_id, + dtype=torch.long, + device=device + ) + + for i, seq in enumerate(batched_input_ids): + final_batch_input_ids[i, -len(seq):] = seq + + return { + "input_ids": final_batch_input_ids, + "gt_texts": [item['gt_text'] for item in batch], + "sample_ids": [item['sample_id'] for item in batch] + } + +################################################################################################ +# +++++++++++++++++++++++++++++++++++++ EVALUATION HELPERS +++++++++++++++++++++++++++++++++++++ +################################################################################################ + +def add_gumbel_noise(logits, temperature): + ''' + The Gumbel max is a method for sampling categorical distributions. + According to arXiv:2409.02908, for MDM, low-precision Gumbel Max improves perplexity score but reduces generation quality. + Thus, we use float64. + ''' + if temperature == 0: + return logits + logits = logits.to(torch.float64) + noise = torch.rand_like(logits, dtype=torch.float64) + gumbel_noise = (- torch.log(noise)) ** temperature + return logits.exp() / gumbel_noise + + +def get_num_transfer_tokens(mask_index, steps): + ''' + In the reverse process, the interval [0, 1] is uniformly discretized into steps intervals. + Furthermore, because LLaDA employs a linear noise schedule (as defined in Eq. (8)), + the expected number of tokens transitioned at each step should be consistent. + + This function is designed to precompute the number of tokens that need to be transitioned at each step. + ''' + mask_num = mask_index.sum(dim=1, keepdim=True) + + base = mask_num // steps + remainder = mask_num % steps + + num_transfer_tokens = torch.zeros(mask_num.size(0), steps, device=mask_index.device, dtype=torch.int64) + base + + for i in range(mask_num.size(0)): + num_transfer_tokens[i, :remainder[i]] += 1 + + return num_transfer_tokens + +@ torch.no_grad() +def generate(model, prompt, steps=128, gen_length=128, block_length=128, temperature=0., + cfg_scale=0., remasking='low_confidence', mask_id=126336, attention_mask=None): + ''' + Args: + model: Mask predictor. + prompt: A tensor of shape (B, L), where B is batch size. + steps: Sampling steps, less than or equal to gen_length. + gen_length: Generated answer length. + block_length: Block length, less than or equal to gen_length. If less than gen_length, it means using semi_autoregressive remasking. + temperature: Categorical distribution sampling temperature. + cfg_scale: Unsupervised classifier-free guidance scale. + remasking: Remasking strategy. 'low_confidence' or 'random'. + mask_id: The toke id of [MASK] is 126336. + ''' + if attention_mask is not None and 0.0 in attention_mask: + attention_bias = (attention_mask[:, :, None] & attention_mask[:, None, :]).bool().unsqueeze(1) + print(f"attention_bias: {attention_bias}") + else: + attention_bias = None + batch_size = prompt.shape[0] + x = torch.full((batch_size, prompt.shape[1] + gen_length), mask_id, dtype=torch.long).to(model.device) + x[:, :prompt.shape[1]] = prompt.clone() + + prompt_index = (x != mask_id) + + assert gen_length % block_length == 0 + num_blocks = gen_length // block_length + + assert steps % num_blocks == 0 + steps = steps // num_blocks + + for num_block in range(num_blocks): + block_mask_index = (x[:, prompt.shape[1] + num_block * block_length: prompt.shape[1] + (num_block + 1) * block_length:] == mask_id) + num_transfer_tokens = get_num_transfer_tokens(block_mask_index, steps) + for i in range(steps): + mask_index = (x == mask_id) + if cfg_scale > 0.: + un_x = x.clone() + un_x[prompt_index] = mask_id + x_ = torch.cat([x, un_x], dim=0) + logits = model(x_).logits + logits, un_logits = torch.chunk(logits, 2, dim=0) + logits = un_logits + (cfg_scale + 1) * (logits - un_logits) + else: + logits = model(x, attention_bias=attention_bias).logits + + logits_with_noise = add_gumbel_noise(logits, temperature=temperature) + x0 = torch.argmax(logits_with_noise, dim=-1) # b, l + + if remasking == 'low_confidence': + p = F.softmax(logits.to(torch.float64), dim=-1) + x0_p = torch.squeeze( + torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1) # b, l + elif remasking == 'random': + x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device) + else: + raise NotImplementedError(remasking) + + x0_p[:, prompt.shape[1] + (num_block + 1) * block_length:] = -np.inf + + x0 = torch.where(mask_index, x0, x) + confidence = torch.where(mask_index, x0_p, -np.inf) + # print(confidence.shape) + transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device) + for j in range(confidence.shape[0]): + _, select_index = torch.topk(confidence[j], k=num_transfer_tokens[j, i]) + transfer_index[j, select_index] = True + x[transfer_index] = x0[transfer_index] + + return x + +def normalize_text(text): + """A simple normalizer for WER calculation.""" + text = text.lower() + text = re.sub(r"[^\w\s']", "", text) + return text + +def calculate_wer(predictions, references): + """Calculates the Word Error Rate (WER) between predicted and ground truth texts.""" + predictions = [normalize_text(p) for p in predictions] + references = [normalize_text(r) for r in references] + + total_errors = 0 + total_words = 0 + for pred, ref in zip(predictions, references): + pred_words = pred.split() + ref_words = ref.split() + total_errors += editdistance.eval(pred_words, ref_words) + total_words += len(ref_words) + + wer = total_errors / total_words if total_words > 0 else 0.0 + return wer, total_errors, total_words + +class S2TEvalDataset(Dataset): + def __init__(self, hf_dataset, root_path): + self.hf_dataset = hf_dataset + self.root_path = root_path + + def __len__(self): + return len(self.hf_dataset) + + def __getitem__(self, idx): + example = self.hf_dataset[idx] + sample_id = example['id'] + speaker_id, chapter_id, _ = sample_id.split('-') + audio_path = os.path.join(self.root_path, speaker_id, chapter_id, f"{sample_id}.flac") + + return { + "audio_path": audio_path, + "gt_text": example["text"], + "sample_id": sample_id + } + +# --- T2S Evaluation Dataset --- +class T2SEvalDataset(Dataset): + def __init__(self, hf_dataset): + self.hf_dataset = hf_dataset + def __len__(self): + return len(self.hf_dataset) + def __getitem__(self, idx): + example = self.hf_dataset[idx] + return {"gt_text": example['text'], "sample_id": example['id']} + +def _resolve_mask_schedule(config): + schedule_cfg = getattr(config, "mask_schedule", None) + if isinstance(schedule_cfg, DictConfig): + schedule_name = getattr(schedule_cfg, "schedule", None) + params_cfg = getattr(schedule_cfg, "params", None) + elif isinstance(schedule_cfg, dict): + schedule_name = schedule_cfg.get("schedule") + params_cfg = schedule_cfg.get("params") + else: + schedule_name = None + params_cfg = None + if schedule_name is None: + schedule_name = config.training.get("mask_schedule", "cosine") + params = {} + if params_cfg is not None: + if isinstance(params_cfg, DictConfig): + params = OmegaConf.to_container(params_cfg, resolve=True) or {} + elif isinstance(params_cfg, dict): + params = dict(params_cfg) + else: + params = params_cfg + if not isinstance(params, dict): + params = {} + return get_mask_schedule(schedule_name, **params) +def _tensor_to_pil(image_tensor: torch.Tensor) -> Image.Image: + image = torch.clamp((image_tensor.detach().cpu().float() + 1.0) / 2.0, min=0.0, max=1.0) + array = (image.permute(1, 2, 0).numpy() * 255.0).astype(np.uint8) + return Image.fromarray(array) + +################################################################################################ +# +++++++++++++++++++++++++++++++++++++ T2I EVALUATION LOGIC +++++++++++++++++++++++++++++++++++++ +################################################################################################ + +@torch.no_grad() +def evaluate_t2i(model, vq_model_image, uni_prompting, config, accelerator, global_step): + if not accelerator.is_main_process: + return + logger.info("***** Running T2I Evaluation *****") + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.eval() + prompts_file = "/home/work/AIDAS/MMaDA/validation_prompts/quantative.txt" + if not prompts_file: + logger.warning("No validation prompts file configured. Skipping T2I evaluation.") + return + prompts_path = Path(prompts_file) + if not prompts_path.is_absolute(): + prompts_path = Path.cwd() / prompts_path + if not prompts_path.exists(): + repo_root = Path(__file__).resolve().parents[2] + alt_path = repo_root / prompts_file + if alt_path.exists(): + prompts_path = alt_path + try: + with open(prompts_path, "r", encoding="utf-8") as handle: + prompts = [line.strip() for line in handle if line.strip()] + except OSError as exc: + logger.warning(f"Failed to read validation prompts from '{prompts_path}': {exc}. Skipping T2I evaluation.") + return + if not prompts: + logger.warning("Validation prompts file is empty. Skipping T2I evaluation.") + return + max_samples = getattr(config.experiment, "eval_num_t2i_samples", 8) + if not isinstance(max_samples, int) or max_samples <= 0: + max_samples = 8 + prompts = prompts[:max_samples] + mask_schedule = _resolve_mask_schedule(config) + mask_token_id = unwrapped_model.config.mask_token_id + seq_len = getattr(getattr(config.model, "omada", None), "num_vq_tokens", None) + if seq_len is None: + seq_len = getattr(unwrapped_model.config, "num_vq_tokens", None) + if seq_len is None: + logger.warning("Unable to determine image token sequence length. Skipping T2I evaluation.") + return + seq_len = int(seq_len) + device = accelerator.device + image_tokens = torch.full((len(prompts), seq_len), mask_token_id, dtype=torch.long, device=device) + input_ids, attention_mask = uni_prompting((prompts, image_tokens), 't2i_gen') + if config.training.guidance_scale > 0: + uncond_input_ids, uncond_attention_mask = uni_prompting(([''] * len(prompts), image_tokens), 't2i_gen') + cfg_scale = config.training.guidance_scale + else: + uncond_input_ids, uncond_attention_mask = None, None + cfg_scale = 0.0 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + else: + weight_dtype = torch.float32 + use_autocast = accelerator.device.type == "cuda" and accelerator.mixed_precision != "no" + autocast_ctx = torch.autocast("cuda", dtype=weight_dtype) if use_autocast else contextlib.nullcontext() + with autocast_ctx: + gen_token_ids = unwrapped_model.t2i_generate( + input_ids=input_ids, + uncond_input_ids=uncond_input_ids, + attention_mask=attention_mask, + uncond_attention_mask=uncond_attention_mask, + guidance_scale=cfg_scale, + temperature=config.training.get("generation_temperature", 1.0), + timesteps=config.training.generation_timesteps, + noise_schedule=mask_schedule, + noise_type=config.training.get("noise_type", "mask"), + predict_all_tokens=config.training.get("predict_all_tokens", False), + seq_len=seq_len, + uni_prompting=uni_prompting, + config=config, + ) + gen_token_ids = torch.clamp(gen_token_ids, min=0, max=unwrapped_model.config.codebook_size - 1) + images = vq_model_image.decode_code(gen_token_ids) + images = torch.clamp((images + 1.0) / 2.0, min=0.0, max=1.0) + images = images.permute(0, 2, 3, 1).cpu().numpy() * 255.0 + pil_images = [Image.fromarray(img.astype(np.uint8)) for img in images] + wandb_images = [wandb.Image(img, caption=prompt) for img, prompt in zip(pil_images, prompts)] + accelerator.log({"eval/t2i_samples": wandb_images}, step=global_step) + +################################################################################################ +# +++++++++++++++++++++++++++++++++++++ I2I EVALUATION LOGIC +++++++++++++++++++++++++++++++++++++ +################################################################################################ + +@torch.no_grad() +def evaluate_i2i(model, vq_model_image, uni_prompting, config, accelerator, global_step): + if not accelerator.is_main_process: + return + logger.info("***** Running I2I Evaluation *****") + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.eval() + dataset_cfg = getattr(config.dataset, "params", {}) + resolution = getattr(dataset_cfg, "resolution", 256) + + def _cfg_to_dict(cfg): + if cfg is None: + return None + if isinstance(cfg, dict): + return cfg + if isinstance(cfg, DictConfig): + return OmegaConf.to_container(cfg, resolve=True) + return cfg + + eval_dataset = None + eval_dataset_name = "unknown" + openimage_cfg_raw = None + if isinstance(dataset_cfg, dict): + openimage_cfg_raw = dataset_cfg.get("openimage_i2i") + else: + openimage_cfg_raw = getattr(dataset_cfg, "openimage_i2i", None) + openimage_cfg = _cfg_to_dict(openimage_cfg_raw) + if openimage_cfg: + try: + eval_dataset = OpenImageI2IDataset( + resolution=resolution, + image_root=openimage_cfg.get("image_root"), + sft_jsonl=openimage_cfg.get("sft_jsonl"), + pref_jsonl=openimage_cfg.get("pref_jsonl"), + multi_turn_jsonl=openimage_cfg.get("multi_turn_jsonl"), + prefer_summarized_text=bool(openimage_cfg.get("prefer_summarized_text", True)), + pref_positive_only=bool(openimage_cfg.get("pref_positive_only", True)), + skip_missing=bool(openimage_cfg.get("skip_missing", True)), + max_samples_per_source=openimage_cfg.get("max_samples_per_source"), + max_total_samples=openimage_cfg.get("max_total_samples"), + seed=openimage_cfg.get("seed"), + ) + eval_dataset_name = "OpenImage I2I" + logger.info("Using OpenImage I2I dataset for evaluation (samples=%d).", len(eval_dataset)) + except Exception as exc: + logger.warning("Failed to build OpenImage I2I eval dataset (%s); falling back to HQ-Edit.", exc) + eval_dataset = None + if eval_dataset is None: + eval_dataset = HQEditX2IDataset(split='train', resolution=resolution) + eval_dataset_name = "HQ-Edit" + logger.info("Using HQ-Edit dataset for I2I evaluation.") + + if len(eval_dataset) == 0: + logger.warning("%s evaluation split is empty. Skipping I2I evaluation.", eval_dataset_name) + return + max_samples = getattr(config.experiment, "eval_num_i2i_samples", 8) + + if not isinstance(max_samples, int) or max_samples <= 0: + max_samples = 8 + num_samples = min(max_samples, len(eval_dataset)) + if len(eval_dataset) <= num_samples: + sample_indices = list(range(len(eval_dataset))) + else: + sample_indices = random.sample(range(len(eval_dataset)), num_samples) + samples = [eval_dataset[i] for i in sample_indices] + prompts = [] + original_tensors = [] + target_tensors = [] + for sample in samples: + prompts.append(sample.get("edit_prompt") or sample.get("output_prompt") or "") + original_tensors.append(sample["input_image"]) + target_tensors.append(sample["output_image"]) + original_images = torch.stack(original_tensors, dim=0).to(accelerator.device) + original_tokens = vq_model_image.get_code(original_images) + len(uni_prompting.text_tokenizer) + seq_len = original_tokens.shape[-1] + mask_token_id = unwrapped_model.config.mask_token_id + placeholder = torch.full((num_samples, seq_len), mask_token_id, dtype=torch.long, device=accelerator.device) + input_ids, attention_mask = uni_prompting((prompts, original_tokens, placeholder), 'i2i_gen') + if config.training.guidance_scale > 0: + uncond_input_ids, uncond_attention_mask = uni_prompting( + ([''] * num_samples, original_tokens, placeholder), 'i2i_gen' + ) + cfg_scale = config.training.guidance_scale + else: + uncond_input_ids, uncond_attention_mask = None, None + cfg_scale = 0.0 + mask_schedule = _resolve_mask_schedule(config) + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + else: + weight_dtype = torch.float32 + use_autocast = accelerator.device.type == "cuda" and accelerator.mixed_precision != "no" + autocast_ctx = torch.autocast("cuda", dtype=weight_dtype) if use_autocast else contextlib.nullcontext() + with autocast_ctx: + gen_token_ids = unwrapped_model.i2i_generate( + input_ids=input_ids, + uncond_input_ids=uncond_input_ids, + attention_mask=attention_mask, + uncond_attention_mask=uncond_attention_mask, + guidance_scale=cfg_scale, + temperature=config.training.get("generation_temperature", 1.0), + timesteps=config.training.generation_timesteps, + noise_schedule=mask_schedule, + noise_type=config.training.get("noise_type", "mask"), + seq_len=seq_len, + uni_prompting=uni_prompting, + config=config, + ) + gen_token_ids = torch.clamp(gen_token_ids, min=0, max=unwrapped_model.config.codebook_size - 1) + generated_images = vq_model_image.decode_code(gen_token_ids) + generated_images = torch.clamp((generated_images + 1.0) / 2.0, min=0.0, max=1.0) + gen_images_pil = [Image.fromarray((img.permute(1, 2, 0).cpu().numpy() * 255.0).astype(np.uint8)) for img in generated_images] + source_pil = [_tensor_to_pil(tensor) for tensor in original_tensors] + target_pil = [_tensor_to_pil(tensor) for tensor in target_tensors] + log_resolution = getattr(config.experiment, "eval_image_log_resolution", 512) + wandb_images = [] + for prompt, src, pred, tgt in zip(prompts, source_pil, gen_images_pil, target_pil): + composite = Image.new('RGB', (log_resolution * 3, log_resolution)) + src_resized = src.resize((log_resolution, log_resolution), Image.Resampling.LANCZOS) + pred_resized = pred.resize((log_resolution, log_resolution), Image.Resampling.LANCZOS) + tgt_resized = tgt.resize((log_resolution, log_resolution), Image.Resampling.LANCZOS) + composite.paste(src_resized, (0, 0)) + composite.paste(pred_resized, (log_resolution, 0)) + composite.paste(tgt_resized, (log_resolution * 2, 0)) + wandb_images.append(wandb.Image(composite, caption=f"Prompt: {prompt}")) + accelerator.log({"eval/i2i_samples": wandb_images}, step=global_step) + + +################################################################################################ +# +++++++++++++++++++++++++++++++++++++ S2S EVALUATION LOGIC +++++++++++++++++++++++++++++++++++++ +################################################################################################ +@torch.no_grad() +def evaluate_s2s(model, vq_model_audio, uni_prompting, config, accelerator, global_step): + if not accelerator.is_main_process: + return + + logger.info("***** Running S2S Evaluation *****") + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.eval() + + dataset_cfg = getattr(config.dataset, "params", {}) + if isinstance(dataset_cfg, DictConfig): + dataset_cfg = dataset_cfg + s2s_eval_dir = getattr(dataset_cfg, "s2s_eval_dir", "MMaDA/validation_prompts/s2s") + + base_path = Path(s2s_eval_dir) + if not base_path.is_absolute(): + base_path = Path.cwd() / base_path + if not base_path.exists(): + repo_root = Path(__file__).resolve().parents[2] + alt_path = repo_root / s2s_eval_dir + if alt_path.exists(): + base_path = alt_path + + if not base_path.exists(): + logger.warning(f"S2S evaluation directory '{s2s_eval_dir}' not found. Skipping S2S evaluation.") + return + + audio_exts = {".wav", ".flac", ".mp3", ".ogg", ".m4a"} + wav_files = sorted(p for p in base_path.iterdir() if p.is_file() and p.suffix.lower() in audio_exts) + if not wav_files: + logger.warning(f"No audio files found in '{base_path}'. Skipping S2S evaluation.") + return + + condition = getattr(dataset_cfg, "s2s_eval_condition", "gender-female_emotion-neutral_speed-normal_pitch-normal") + mask_token_id = unwrapped_model.config.mask_token_id + codebook_size = int(getattr(config.model.omada, "codebook_size", 8192)) + + speech_vocab_start = len(uni_prompting.text_tokenizer) + codebook_size + audio_codebook_size = max(int(config.model.omada.new_vocab_size) - speech_vocab_start, 0) + if audio_codebook_size <= 0: + logger.warning("Computed audio codebook size is non-positive. Skipping S2S evaluation.") + return + + a_tokens = uni_prompting.text_tokenizer("<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n", return_tensors="pt").input_ids + soa_len = uni_prompting.sptids_dict['<|soa|>'].numel() + eoa_len = uni_prompting.sptids_dict['<|eoa|>'].numel() + asst_header_len = a_tokens.shape[1] + max_audio_len = getattr(uni_prompting, "max_audio_len", 256) + max_generatable = max(1, max_audio_len - (asst_header_len + soa_len + eoa_len)) + + offset = len(uni_prompting.text_tokenizer) + codebook_size + device = accelerator.device + + output_root = Path(config.experiment.output_dir) / "eval_s2s" / f"step_{global_step}" + output_root.mkdir(parents=True, exist_ok=True) + + table = wandb.Table(columns=["Audio ID", "Source Audio", "Generated Audio", "Token Count"]) + + for audio_path in wav_files: + try: + user_tokens = vq_model_audio.encode(str(audio_path)) + except Exception as exc: + logger.error(f"Failed to encode '{audio_path}': {exc}") + continue + + if not isinstance(user_tokens, torch.Tensor): + user_tokens = torch.tensor(user_tokens) + if user_tokens.dim() == 1: + user_tokens = user_tokens.unsqueeze(0) + + user_tokens = user_tokens.to(device=device, dtype=torch.long) + if user_tokens.numel() == 0: + logger.warning(f"Encoded audio from '{audio_path}' produced no tokens. Skipping sample.") + continue + + assistant_len = max_generatable + if assistant_len <= 0: + logger.warning(f"Assistant placeholder length for '{audio_path}' is non-positive. Skipping sample.") + continue + + user_shifted = user_tokens + offset + assistant_placeholder = torch.full( + (1, assistant_len), + mask_token_id, + dtype=torch.long, + device=device, + ) + + input_ids, _prompt_masks, _ = uni_prompting( + ([user_shifted], [assistant_placeholder]), + 's2s_eos' + ) + + attention_mask = (input_ids != uni_prompting.pad_id).long() + + try: + generated_sequences = unwrapped_model.t2s_generate_mmu_like( + input_ids=input_ids, + max_new_tokens=256, + steps=256, + block_length=256, + temperature=config.training.get("s2s_generation_temperature", 1.0), + cfg_scale=config.training.get("s2s_guidance_scale", 2.5), + mask_token_id=mask_token_id, + attention_mask=attention_mask, + uni_prompting=uni_prompting, + codebook_size=codebook_size, + audio_codebook_size=audio_codebook_size, + ) + except Exception as exc: + logger.error(f"Generation failed for '{audio_path}': {exc}") + continue + + if not generated_sequences: + logger.warning(f"No tokens generated for '{audio_path}'. Skipping sample.") + continue + + gen_tokens = generated_sequences[0] + if isinstance(gen_tokens, torch.Tensor): + gen_tokens = gen_tokens.detach().cpu() + token_list = gen_tokens.tolist() + if not token_list: + logger.warning(f"Generated token list empty for '{audio_path}'. Skipping sample.") + continue + + speech_unit_str = "".join([f"<|speech_{int(token)}|>" for token in token_list]) + output_path = output_root / f"{audio_path.stem}_reply.wav" + + try: + vq_model_audio.decode(speech_unit_str, condition=condition, output_wav_file=str(output_path)) + except Exception as exc: + logger.error(f"Decoding failed for '{audio_path}': {exc}") + continue + + table.add_data( + audio_path.name, + wandb.Audio(str(audio_path), caption="source"), + wandb.Audio(str(output_path), caption="generated"), + len(token_list), + ) + + row_count = getattr(table, "num_rows", None) + if row_count is None: + table_data = getattr(table, "data", None) + row_count = len(table_data) if table_data is not None else 0 + + if row_count > 0: + accelerator.log({"eval/s2s_samples": table}, step=global_step) + else: + logger.warning("S2S evaluation produced no loggable samples.") + + +################################################################################################ +# +++++++++++++++++++++++++++++++++++++ TEXT EVALUATION LOGIC ++++++++++++++++++++++++++++++++++++ +################################################################################################ +@torch.no_grad() +def evaluate_text(model, uni_prompting, config, accelerator, global_step): + if not accelerator.is_main_process: + return + + logger.info("***** Running Text Evaluation *****") + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.eval() + + dataset_cfg = getattr(config.dataset, "params", {}) + prompts_file = getattr(dataset_cfg, "text_eval_prompts_file", "MMaDA/validation_prompts/math.txt") + + prompts_path = Path(prompts_file) + if not prompts_path.is_absolute(): + prompts_path = Path.cwd() / prompts_path + if not prompts_path.exists(): + repo_root = Path(__file__).resolve().parents[2] + alt_path = repo_root / prompts_file + if alt_path.exists(): + prompts_path = alt_path + + if not prompts_path.exists(): + logger.warning(f"Text evaluation prompts file '{prompts_file}' not found. Skipping text evaluation.") + return + + try: + with open(prompts_path, "r", encoding="utf-8") as handle: + raw_prompts = [line.strip() for line in handle if line.strip()] + except OSError as exc: + logger.warning(f"Failed to read text evaluation prompts from '{prompts_path}': {exc}. Skipping text evaluation.") + return + + if not raw_prompts: + logger.warning("Text evaluation prompt list is empty. Skipping text evaluation.") + return + + max_samples = getattr(config.experiment, "eval_num_text_samples", 4) + if not isinstance(max_samples, int) or max_samples <= 0: + max_samples = 4 + questions = raw_prompts[:max_samples] + + chat_prompts = [ + f"<|start_header_id|>user<|end_header_id|>\n{question}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n" + for question in questions + ] + + tokenizer = uni_prompting.text_tokenizer + tokenizer.padding_side = "left" + if tokenizer.pad_token_id is None: + tokenizer.pad_token_id = tokenizer.eos_token_id + + answers: list[str] = [] + + for chat_prompt in chat_prompts: + tokens = tokenizer( + chat_prompt, + return_tensors="pt", + padding=True, + truncation=True, + ) + + input_ids = tokens["input_ids"].to(accelerator.device) + out = generate(unwrapped_model, input_ids, steps=128, gen_length=128, block_length=128, temperature=1, cfg_scale=0., remasking='low_confidence') + answer = tokenizer.batch_decode(out[:, input_ids.shape[1]:], skip_special_tokens=True) + + answers.append(answer) + + table = wandb.Table(columns=["Index", "Question", "Answer"]) + for idx, (question, answer) in enumerate(zip(questions, answers)): + table.add_data(idx, question, answer) + + accelerator.log({"eval/text_samples": table}, step=global_step) + +################################################################################################ +# +++++++++++++++++++++++++++++++++++++ S2T EVALUATION LOGIC +++++++++++++++++++++++++++++++++++++ +################################################################################################ +@torch.no_grad() +def evaluate_s2t(model, vq_model_audio, uni_prompting, config, accelerator, global_step): + if not accelerator.is_main_process: + return + logger.info("***** Running S2T Evaluation (WER on Librispeech test-clean) *****") + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.eval() + + # 1. Load Dataset + try: + s2t_eval_dataset_raw = load_dataset("librispeech_asr", "clean", split="test", streaming=False).select(range(8)) + s2t_eval_dataset = S2TEvalDataset(s2t_eval_dataset_raw, root_path = "/home/work/AIDAS/data/audio/LibriSpeech/test-clean") + except Exception as e: + logger.error(f"Failed to load S2T evaluation dataset: {e}") + return + + collate_with_args = partial( + s2t_eval_collate_fn, + vq_model_audio=vq_model_audio, + tokenizer=uni_prompting.text_tokenizer, + uni_prompting=uni_prompting, + config=config + ) + + s2t_eval_dataloader = DataLoader(s2t_eval_dataset, batch_size=config.training.batch_size_s2t, shuffle=False, collate_fn=collate_with_args) + + local_results = [] + + for batch in tqdm(s2t_eval_dataloader, desc="S2T Evaluation"): + input_ids = batch["input_ids"] + gt_texts = batch["gt_texts"] + sample_ids = batch["sample_ids"] + + output_ids = unwrapped_model.mmu_generate(input_ids, max_new_tokens=256, steps=256, block_length=128, remasking='low_confidence') + + decoded_texts = uni_prompting.text_tokenizer.batch_decode(output_ids[:, input_ids.shape[1]:], skip_special_tokens=True) + + eos_token = uni_prompting.text_tokenizer.eos_token + eos_marker = eos_token if eos_token is not None else "" + for i in range(len(decoded_texts)): + full_text = decoded_texts[i] + eos_idx = full_text.find(eos_marker) + cleaned_text = full_text[:eos_idx] if eos_idx != -1 else full_text + cleaned_text = cleaned_text.replace(eos_marker, "").strip() + local_results.append({ + "sample_id": sample_ids[i], + "gt_text": gt_texts[i], + "decoded_text": cleaned_text, + }) + + if not local_results: + logger.warning("S2T evaluation produced no results.") + return + + gt_list = [res["gt_text"] for res in local_results] + pred_list = [res["decoded_text"] for res in local_results] + + wer, errors, words = calculate_wer(pred_list, gt_list) + logger.info(f"S2T Final WER (Librispeech test-clean): {wer:.4f} | Word Errors: {errors} | Total Words: {words}") + + accelerator.log({ + "eval/s2t_wer": wer, + "eval/s2t_word_errors": errors, + "eval/s2t_total_words": words + }, step=global_step) + + samples_table = wandb.Table(columns=["ID", "Ground Truth", "Prediction"]) + for idx, res in enumerate(local_results): + sample_id = res.get("sample_id", idx) + samples_table.add_data(sample_id, res["gt_text"], res["decoded_text"]) + + accelerator.log({"eval/s2t_samples": samples_table}, step=global_step) + +################################################################################################ +# +++++++++++++++++++++++++++++++++++++ T2S EVALUATION LOGIC +++++++++++++++++++++++++++++++++++++ +################################################################################################ +@torch.no_grad() +def evaluate_t2s(model, vq_model_audio, uni_prompting, config, accelerator, global_step): + if not accelerator.is_main_process: + return + logger.info("***** Running T2S Evaluation (WER via Whisper on Librispeech) *****") + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.eval() + + # 1. Load Dataset & Whisper Model + try: + t2s_eval_dataset_raw = load_dataset("librispeech_asr", "clean", split="test").select(range(8)) + whisper_pipe = pipeline("automatic-speech-recognition", model="openai/whisper-large-v3", device=accelerator.device) + os.makedirs(f"{config.experiment.output_dir}/eval_audio", exist_ok=True) + except Exception as e: + logger.error(f"Failed to load T2S dataset or Whisper model: {e}") + return + + output_dir_per_step = os.path.join("/home/work/AIDAS", config.experiment.output_dir, "eval_audio", f"step_{global_step}") + os.makedirs(output_dir_per_step, exist_ok=True) + + t2s_eval_dataset = T2SEvalDataset(t2s_eval_dataset_raw) + t2s_dataloader = DataLoader(t2s_eval_dataset, batch_size=config.training.batch_size_t2s) + + local_results = [] + mask_token_id = unwrapped_model.config.mask_token_id + mask_schedule = get_mask_schedule(config.training.get("mask_schedule", "cosine")) + + # 2. Evaluation Loop + for batch in tqdm(t2s_dataloader, desc="T2S Evaluation"): + gt_texts = batch["gt_text"] + sample_ids = batch["sample_id"] + + # Chat-style instruction formatting for T2S: user prompt + text + prompts = [ + f"<|start_header_id|>user<|end_header_id|>\n{random.choice(T2S_INSTRUCTION)}\n{text}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n" + for text in gt_texts + ] + batch_size = len(prompts) + + # We need a reasonable length for generated audio tokens + speech_token_length = 384 - 1 # -1 for soa token + audio_tokens = torch.ones((batch_size, speech_token_length), dtype=torch.long, device=accelerator.device) * mask_token_id + input_ids, attention_mask = uni_prompting((prompts, audio_tokens), 't2s_gen') + + if config.training.guidance_scale > 0: + uncond_input_ids, uncond_attention_mask = uni_prompting(([''] * batch_size, audio_tokens), 't2s_gen') + else: + uncond_input_ids, uncond_attention_mask = None, None + + output_ids = unwrapped_model.t2s_generate( + input_ids=input_ids, + uncond_input_ids=uncond_input_ids, + attention_mask=attention_mask, + uncond_attention_mask=uncond_attention_mask, + guidance_scale=5.0, + temperature=1.0, + timesteps=50, + noise_schedule=mask_schedule, + noise_type="mask", + seq_len=383, + uni_prompting=uni_prompting, + config=config, + ) + + # Decode and run Whisper + for i in range(batch_size): + gt = gt_texts[i].rsplit("\n", 1)[-1].strip() + + gen_speech_tokens = output_ids[i] + + # Remove padding/eos if necessary, clamp to valid range + # gen_speech_tokens = torch.clamp(gen_speech_tokens, min=0, max= 4096 - 1) + id_list = gen_speech_tokens.cpu().tolist() + + if not id_list: + logger.warning(f"Generated token list is empty for sample {sample_ids[i]}. Skipping.") + continue + + speech_unit_str = " ".join(map(str, id_list)) + speech_unit_for_decode = "".join([f"<|speech_{unit}|>" for unit in speech_unit_str.split(" ")]) + + filename = f"process_{accelerator.process_index}_{sample_ids[i]}.wav" + output_wav_path = os.path.join(output_dir_per_step, filename) + condition = 'gender-female_emotion-neutral_speed-normal_pitch-normal' + + audio_array = vq_model_audio.decode(speech_unit_for_decode, condition=condition, output_wav_file=output_wav_path) + + whisper_result = whisper_pipe(output_wav_path, generate_kwargs={"language": "english"}) + whisper_text = whisper_result.get("text", "") + + local_results.append({ + "sample_id": sample_ids[i], "gt_text": gt, "whisper_text": whisper_text, "audio_path": output_wav_path + }) + + if not local_results: + logger.warning("Skipping T2S evaluation logging because no samples were generated.") + return + + gt_list = [res["gt_text"] for res in local_results] + pred_list = [res["whisper_text"] for res in local_results] + + wer, errors, words = calculate_wer(pred_list, gt_list) + logger.info(f"T2S Final WER (via Whisper): {wer:.4f} | Word Errors: {errors} | Total Words: {words}") + + accelerator.log({ + "eval/t2s_wer": wer, + "eval/t2s_word_errors": errors, + "eval/t2s_total_words": words + }, step=global_step) + + results_table = wandb.Table(columns=["ID", "Ground Truth", "Whisper Transcription", "Generated Audio"]) + for res in local_results[:8]: + audio = wandb.Audio(res["audio_path"], caption=res["whisper_text"]) + results_table.add_data(res["sample_id"], res["gt_text"], res["whisper_text"], audio) + + accelerator.log({"eval/t2s_samples": results_table}, step=global_step) + +@torch.no_grad() +def evaluate_t2s_mmu_like(model, vq_model_audio, uni_prompting, config, accelerator, global_step): + """Text-to-speech evaluation using the MMU-style block refinement decoder.""" + + if not accelerator.is_main_process: + return + + logger.info("***** Running T2S Evaluation (MMU-style decoder) *****") + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.eval() + + try: + t2s_eval_dataset_raw = load_dataset("librispeech_asr", "clean", split="test").select(range(8)) + whisper_pipe = pipeline("automatic-speech-recognition", model="openai/whisper-large-v3", device=accelerator.device) + os.makedirs(f"{config.experiment.output_dir}/eval_audio", exist_ok=True) + except Exception as exc: + logger.error(f"Failed to load T2S dataset or Whisper model for MMU-style eval: {exc}") + return + + output_dir_per_step = os.path.join("/home/work/AIDAS", config.experiment.output_dir, "eval_audio", f"step_{global_step}_mmu") + os.makedirs(output_dir_per_step, exist_ok=True) + + t2s_eval_dataset = T2SEvalDataset(t2s_eval_dataset_raw) + t2s_dataloader = DataLoader(t2s_eval_dataset, batch_size=config.training.batch_size_t2s) + + local_results = [] + mask_token_id = unwrapped_model.config.mask_token_id + + codebook_size = config.model.omada.codebook_size + speech_vocab_size = 4096 + + for batch in tqdm(t2s_dataloader, desc="T2S MMU Eval"): + gt_texts = batch["gt_text"] + sample_ids = batch["sample_id"] + + prompts = [ + f"<|start_header_id|>user<|end_header_id|>\n{random.choice(T2S_INSTRUCTION)}\n{text}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n" + for text in gt_texts + ] + + batch_size = len(prompts) + speech_token_length = 384 - 1 + audio_tokens = torch.ones((batch_size, speech_token_length), dtype=torch.long, device=accelerator.device) * mask_token_id + input_ids, attention_mask = uni_prompting((prompts, audio_tokens), 't2s_gen') + + output_ids = unwrapped_model.t2s_generate_mmu_like( + input_ids=input_ids, + max_new_tokens=speech_token_length, + steps=384 - 1, + block_length=384 - 1, + temperature=1.0, + cfg_scale=3.0, + mask_token_id=mask_token_id, + attention_mask=attention_mask, + uni_prompting=uni_prompting, + codebook_size=codebook_size, + audio_codebook_size=speech_vocab_size, + ) + + for i in range(batch_size): + gt = gt_texts[i].rsplit("\n", 1)[-1].strip() + + gen_speech_tokens = output_ids[i] + if isinstance(gen_speech_tokens, torch.Tensor): + gen_speech_tokens = gen_speech_tokens.detach().cpu() + + token_list = gen_speech_tokens.tolist() + if not token_list: + logger.warning(f"Generated token list is empty for sample {sample_ids[i]} (MMU eval). Skipping.") + continue + + speech_unit_str = " ".join(map(str, token_list)) + speech_unit_for_decode = "".join([f"<|speech_{unit}|>" for unit in speech_unit_str.split(" ")]) + + filename = f"process_{accelerator.process_index}_{sample_ids[i]}_mmu.wav" + output_wav_path = os.path.join(output_dir_per_step, filename) + condition = 'gender-female_emotion-neutral_speed-normal_pitch-normal' + + try: + vq_model_audio.decode(speech_unit_for_decode, condition=condition, output_wav_file=output_wav_path) + except Exception as exc: + logger.error(f"Decoding failed for sample {sample_ids[i]} (MMU eval): {exc}") + continue + + whisper_result = whisper_pipe(output_wav_path, generate_kwargs={"language": "english"}) + whisper_text = whisper_result.get("text", "") + + local_results.append({ + "sample_id": sample_ids[i], + "gt_text": gt, + "whisper_text": whisper_text, + "audio_path": output_wav_path, + }) + + if not local_results: + logger.warning("Skipping T2S MMU-style evaluation because no samples were generated.") + return + + gt_list = [res["gt_text"] for res in local_results] + pred_list = [res["whisper_text"] for res in local_results] + + wer, errors, words = calculate_wer(pred_list, gt_list) + logger.info(f"T2S (MMU-style) Final WER: {wer:.4f} | Word Errors: {errors} | Total Words: {words}") + + accelerator.log({ + "eval/t2s_mmu_like_wer": wer, + "eval/t2s_mmu_like_word_errors": errors, + "eval/t2s_mmu_like_total_words": words, + }, step=global_step) + + results_table = wandb.Table(columns=["ID", "Ground Truth", "Whisper Transcription", "Generated Audio"]) + for res in local_results[:8]: + audio = wandb.Audio(res["audio_path"], caption=res["whisper_text"]) + results_table.add_data(res["sample_id"], res["gt_text"], res["whisper_text"], audio) + + accelerator.log({"eval/t2s_mmu_like_samples": results_table}, step=global_step) + +@torch.no_grad() +def evaluate_t2s_fixed(model, vq_model_audio, uni_prompting, config, accelerator, global_step): + """ + Text-to-Speech (fixed-length) evaluation: + - Input prompt contains SOA + [MASK]*L + EOA (EOA is injected, not predicted) + - The model only fills VQ codes for exactly L positions (no EOA/EOS prediction) + - Generated audio is transcribed by Whisper; we report WER + """ + if not accelerator.is_main_process: + return + logger.info("***** Running T2S (fixed-length) Evaluation *****") + unwrapped = accelerator.unwrap_model(model) + unwrapped.eval() + + # Load eval dataset and Whisper model + try: + ds_raw = load_dataset("librispeech_asr", "clean", split="test").select(range(128)) + whisper_pipe = pipeline( + "automatic-speech-recognition", + model="openai/whisper-large-v3", + device=accelerator.device + ) + os.makedirs(f"{config.experiment.output_dir}/eval_audio", exist_ok=True) + except Exception as e: + logger.error(f"Failed to load dataset or Whisper model: {e}") + return + + # Directory for saving generated audio files of this evaluation step + out_dir = os.path.join( + "/home/work/AIDAS", config.experiment.output_dir, "eval_audio", f"step_{global_step}_fixed" + ) + os.makedirs(out_dir, exist_ok=True) + + eval_ds = T2SEvalDataset(ds_raw) + loader = DataLoader(eval_ds, batch_size=config.training.batch_size_t2s) + + local_results = [] + mask_token_id = unwrapped.config.mask_token_id + mask_schedule = get_mask_schedule(config.training.get("mask_schedule", "cosine")) + + for batch in tqdm(loader, desc="T2S Fixed Evaluation"): + gt_texts = batch["gt_text"] + sample_ids = batch["sample_id"] + + # Chat-style instruction formatting for fixed-length T2S + prompts = [ + f"<|start_header_id|>user<|end_header_id|>\n{random.choice(T2S_INSTRUCTION)}\n{text}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n" + for text in gt_texts + ] + batch_size = len(prompts) + + # We need a reasonable length for generated audio tokens + speech_token_length = 256 - 2 # exclude and + audio_tokens = torch.ones((batch_size, speech_token_length), dtype=torch.long, device=accelerator.device) * mask_token_id + input_ids, attention_mask = uni_prompting((prompts, audio_tokens), 't2s_fixed_gen') + + if config.training.guidance_scale > 0: + uncond_input_ids, uncond_attention_mask = uni_prompting(([''] * batch_size, audio_tokens), 't2s_fixed_gen') + else: + uncond_input_ids, uncond_attention_mask = None, None + + # Core generation call: + # - predict_eoa=False prevents EOA/EOS prediction; only VQ codes are generated + outputs = unwrapped.t2s_fixed_generate( + input_ids=input_ids, + uncond_input_ids=uncond_input_ids, + attention_mask=attention_mask, + uncond_attention_mask=uncond_attention_mask, + guidance_scale=1.5, + temperature=1.0, + timesteps=24, + noise_schedule=mask_schedule, + noise_type="mask", + seq_len=speech_token_length, + uni_prompting=uni_prompting, + config=config, + ) + + # Decode generated VQ codes → waveform via the speech tokenizer, then ASR with Whisper + for i in range(batch_size): + gt = gt_texts[i].rsplit("\n", 1)[-1].strip() + gen_rel = outputs[i] # relative VQ ids in [0..4095] + id_list = gen_rel.tolist() + + if not id_list: + logger.warning(f"[fixed] Empty tokens for {sample_ids[i]}; skipping.") + continue + + # Convert to the speech-unit string format expected by the decoder + unit_str = " ".join(map(str, id_list)) + speech_unit_for_decode = "".join([f"<|speech_{u}|>" for u in unit_str.split(" ")]) + + # Synthesize audio and run Whisper + fname = f"process_{accelerator.process_index}_{sample_ids[i]}_fixed.wav" + wav_path = os.path.join(out_dir, fname) + condition = 'gender-female_emotion-neutral_speed-normal_pitch-normal' + + _ = vq_model_audio.decode( + speech_unit_for_decode, + condition=condition, + output_wav_file=wav_path + ) + asr = whisper_pipe(wav_path, generate_kwargs={"language": "english"}) + whisper_text = asr.get("text", "") + + local_results.append({ + "sample_id": sample_ids[i], + "gt_text": gt, + "whisper_text": whisper_text, + "audio_path": wav_path + }) + + if not local_results: + logger.warning("Skipping T2S fixed evaluation logging because no samples were generated.") + return + + gt_list = [r["gt_text"] for r in local_results] + pred_list = [r["whisper_text"] for r in local_results] + wer, errors, words = calculate_wer(pred_list, gt_list) + logger.info(f"T2S Fixed WER: {wer:.4f} | Errors: {errors} | Words: {words}") + + accelerator.log({ + "eval/t2s_fixed_wer": wer, + "eval/t2s_fixed_errors": errors, + "eval/t2s_fixed_words": words + }, step=global_step) + + table = wandb.Table(columns=["ID", "GT", "ASR", "Audio"]) + for r in local_results[:8]: + table.add_data( + r["sample_id"], + r["gt_text"], + r["whisper_text"], + wandb.Audio(r["audio_path"], caption=r["whisper_text"]) + ) + accelerator.log({"eval/t2s_fixed_samples": table}, step=global_step) + +################################################################################################ +# +++++++++++++++++++++++++++++++++++++ V2T EVALUATION LOGIC +++++++++++++++++++++++++++++++++++++ +################################################################################################ +@torch.no_grad() +def evaluate_v2t(model, vq_model_image, uni_prompting, config, accelerator, global_step): + # This is a qualitative evaluation, so it only runs on the main process. + if not accelerator.is_main_process: + return + + logger.info("***** Running V2T Qualitative Evaluation *****") + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.eval() + + video_root = "/home/work/AIDAS/video/demo" + if not video_root or not os.path.exists(video_root): + logger.warning(f"V2T eval root '{video_root}' not found. Skipping V2T evaluation.") + return + + file_list = [f for f in os.listdir(video_root) if f.lower().endswith('.mp4')] + if not file_list: + logger.warning(f"No .mp4 files found in '{video_root}'. Skipping V2T evaluation.") + return + + question = "Please provide a detailed description of the video." + results_table = wandb.Table(columns=["Video ID", "Question", "Generated Caption"]) + + for file_name in tqdm(file_list[:], desc="V2T Evaluation", disable=not accelerator.is_main_process): + video_path = os.path.join(video_root, file_name) + + # 1. Load and process video + cap = cv2.VideoCapture(video_path) + total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + indices = np.linspace(0, total_frames - 1, 8, dtype=int) + frames = [] + for i in range(total_frames): + ret, frame = cap.read() + if i in indices: + if not ret: continue + frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + pil_img = Image.fromarray(frame) + frames.append(image_transform(pil_img, resolution=config.dataset.preprocessing.resolution)) + cap.release() + + if len(frames) < 8: continue + + video_tensor = torch.stack(frames).to(accelerator.device) + video_tokens = vq_model_image.get_code(video_tensor) + len(uni_prompting.text_tokenizer) + video_tokens = video_tokens.view(1, -1) # Flatten tokens + + sptids = uni_prompting.sptids_dict + device = unwrapped_model.device + + prompt_text = f'<|start_header_id|>user<|end_header_id|>\n{question}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n' + prompt_tensor = uni_prompting.text_tokenizer(prompt_text, return_tensors="pt").input_ids.to(device) + + input_ids = torch.cat([ + sptids['<|v2t|>'].to(device).unsqueeze(0), + sptids['<|soi|>'].to(device).unsqueeze(0), + video_tokens, + sptids['<|eoi|>'].to(device).unsqueeze(0), + sptids['<|sot|>'].to(device).unsqueeze(0), + prompt_tensor + ], dim=1).long() + + output_ids = unwrapped_model.mmu_generate(input_ids, max_new_tokens=256, steps=256, block_length=128) + text = uni_prompting.text_tokenizer.batch_decode(output_ids[:, input_ids.shape[1]:], skip_special_tokens=True)[0] + print(text) + # 3. Log result + results_table.add_data(file_name, question, text) + + # except Exception as e: + # logger.error(f"Error processing video {file_name}: {e}") + + accelerator.log({"eval/v2t_qualitative_samples": results_table}, step=global_step) + + +################################################################################################ +# +++++++++++++++++++++++++++++++++++++ MAIN EVALUATION ORCHESTRATOR +++++++++++++++++++++++++++++ +################################################################################################ + +def run_evaluation(model, vq_model_image, vq_model_audio, uni_prompting, config, accelerator, global_step): + """ + Orchestrates the S2T, T2S, and V2T evaluations. + """ + if accelerator.is_main_process: + logger.info(f"--- Starting evaluation at step {global_step} ---") + model.eval() + + + if accelerator.is_main_process: + evaluate_text(model, uni_prompting, config, accelerator, global_step) + evaluate_t2i(model, vq_model_image, uni_prompting, config, accelerator, global_step) + evaluate_i2i(model, vq_model_image, uni_prompting, config, accelerator, global_step) + evaluate_s2s(model, vq_model_audio, uni_prompting, config, accelerator, global_step) + evaluate_s2t(model, vq_model_audio, uni_prompting, config, accelerator, global_step) + evaluate_t2s_mmu_like(model, vq_model_audio, uni_prompting, config, accelerator, global_step) + # evaluate_v2t(model, vq_model_image, uni_prompting, config, accelerator, global_step) + + accelerator.wait_for_everyone() + if accelerator.is_main_process: + logger.info(f"--- Finished evaluation at step {global_step}. Returning to training. ---") + model.train() + + +def main(): + ######################### + # SETUP Accelerator # + ######################### + config = get_config() + + # Enable TF32 on Ampere GPUs + if config.training.enable_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.benchmark = True + torch.backends.cudnn.deterministic = False + + config.experiment.logging_dir = str(Path(config.experiment.output_dir) / "logs") + accelerator = Accelerator( + gradient_accumulation_steps=config.training.gradient_accumulation_steps, + mixed_precision=config.training.mixed_precision, + log_with="wandb", + project_dir=config.experiment.logging_dir, + split_batches=True, + ) + + total_batch_size_per_gpu = ( + config.training.batch_size_t2i + + config.training.batch_size_lm + + config.training.batch_size_mmu + + config.training.batch_size_v2t + + config.training.batch_size_s2t + + config.training.batch_size_t2s + + config.training.batch_size_s2s + ) - 1 # -1 since t2s/ s2t choice + + total_batch_size = ( + total_batch_size_per_gpu + * accelerator.num_processes + * config.training.gradient_accumulation_steps + ) + + if accelerator.distributed_type == DistributedType.DEEPSPEED: + accelerator.state.deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = ( + total_batch_size_per_gpu + ) + + ##################################### + # SETUP LOGGING, SEED and CONFIG # + ##################################### + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + set_verbosity_info() + else: + set_verbosity_error() + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + resume_wandb_run = config.wandb.resume + run_id = config.wandb.get("run_id", None) + if run_id is None: + resume_wandb_run = False + run_id = wandb.util.generate_id() + config.wandb.run_id = run_id + + wandb_init_kwargs = dict( + name=config.experiment.name, + id=run_id, + resume=resume_wandb_run, + entity=config.wandb.get("entity", None), + config_exclude_keys=[], + dir = config.experiment.logging_dir, + ) + wandb_config = {k: v for k, v in flatten_omega_conf(config, resolve=True)} + wandb_config.pop("experiment.resume_from_checkpoint") + + accelerator.init_trackers( + config.experiment.project, + config=wandb_config, + init_kwargs={"wandb": wandb_init_kwargs}, + ) + + if accelerator.is_main_process: + os.makedirs(config.experiment.output_dir, exist_ok=True) + config_path = Path(config.experiment.output_dir) / "config.yaml" + logging.info(f"Saving config to {config_path}") + OmegaConf.save(config, config_path) + + # If passed along, set the training seed now. + if config.training.seed is not None: + set_seed(config.training.seed) + + ######################### + # MODELS and OPTIMIZER # + ######################### + logger.info("Loading models and optimizer") + + tokenizer = AutoTokenizer.from_pretrained(config.model.omada.tokenizer_path, padding_side="left") + + uni_prompting = UniversalPrompting(tokenizer, max_text_len=config.dataset.preprocessing.max_seq_length, max_audio_len=config.dataset.preprocessing.max_aud_length, max_audio_len_s2s=config.dataset.preprocessing.max_aud_length_s2s, + special_tokens=( + "<|soi|>", "<|eoi|>", "<|sov|>", "<|eov|>", "<|t2i|>", + "<|mmu|>", "<|t2v|>", "<|v2v|>", "<|lvg|>", + # Omada Special Tokens + "<|i2i|>", "<|v2t|>", "<|v2s|>", "<|s2t|>", "<|t2s|>", "<|s2s|>", "<|soa|>", "<|eoa|>", + ), + ignore_id=-100, cond_dropout_prob=config.training.cond_dropout_prob, use_reserved_token=True) + + print('special tokens : \n', uni_prompting.sptids_dict) + + speech_vocab_start = len(uni_prompting.text_tokenizer) + int(config.model.omada.codebook_size) + audio_codebook_size = max(int(config.model.omada.new_vocab_size) - speech_vocab_start, 0) + t2s_special_token_ids = { + "eoa": int(uni_prompting.sptids_dict['<|eoa|>'][0].item()), + "eos": int(uni_prompting.text_tokenizer.eos_token_id), + } + + # VQ model for processing image into discrete tokens + vq_model_image = get_vq_model_class(config.model.vq_model_image.type) + if config.model.vq_model_image.get("pretrained_model_path", None): + vq_model_image = vq_model_image().to(accelerator.device) + state_dict = torch.load(config.model.vq_model_image.pretrained_model_path)['model'] + vq_model_image.load_state_dict(state_dict) + else: + vq_model_image = vq_model_image.from_pretrained(config.model.vq_model_image.vq_model_name).to(accelerator.device) + + vq_model_audio = get_vq_model_class(config.model.vq_model_audio.type) + vq_model_audio = vq_model_audio.from_pretrained(config.model.vq_model_audio.vq_model_name).to(accelerator.device) + + vq_model_image.eval() + vq_model_image.requires_grad_(False) + + vq_model_audio.eval() + vq_model_audio.requires_grad_(False) + + model = OMadaModelLM.from_pretrained(config.model.omada.pretrained_model_path, torch_dtype=torch.bfloat16, config='/home/work/AIDAS/ckpts/omada/omada-training-stage1_7th/checkpoint-315000/unwrapped_model/config.json').to(accelerator.device) + mask_id = model.config.mask_token_id + + ################################## + # Optimizer and LR scheduler # + ################################# + optimizer_config = config.optimizer.params + + # no decay on bias and layernorm and embedding + no_decay = ["bias", "layer_norm.weight", "mlm_ln.weight", "embeddings.weight"] + optimizer_grouped_parameters = [ + { + "params": [p for n, p in model.named_parameters() if + p.requires_grad and not any(nd in n for nd in no_decay)], + "weight_decay": optimizer_config.weight_decay, + }, + { + "params": [p for n, p in model.named_parameters() if + p.requires_grad and any(nd in n for nd in no_decay)], + "weight_decay": 0.0, + }, + ] + + optimizer_type = config.optimizer.name + if optimizer_type == "adamw": + optimizer = AdamW( + optimizer_grouped_parameters, + lr=optimizer_config.learning_rate, + betas=(optimizer_config.beta1, optimizer_config.beta2), + weight_decay=optimizer_config.weight_decay, + eps=optimizer_config.epsilon, + ) + else: + raise ValueError(f"Optimizer {optimizer_type} not supported") + + # Create mask scheduler + if config.get("mask_schedule", None) is not None: + schedule = config.mask_schedule.schedule + args = config.mask_schedule.get("params", {}) + mask_schedule = get_mask_schedule(schedule, **args) + else: + mask_schedule = get_mask_schedule(config.training.get("mask_schedule", "cosine")) + + ################################## + # DATALOADER # + ################################# + logger.info("Creating dataloaders and lr_scheduler") + + def build_distributed_sampler(dataset, *, shuffle=True, drop_last=True): + """Create a DistributedSampler only when running with multiple processes.""" + if dataset is None or accelerator.num_processes <= 1: + return None + return DistributedSampler( + dataset, + num_replicas=accelerator.num_processes, + rank=accelerator.process_index, + shuffle=shuffle, + drop_last=drop_last, + ) + + batch_size_t2i_cfg = config.training.batch_size_t2i + batch_size_lm_cfg = config.training.batch_size_lm + batch_size_mmu_cfg = config.training.batch_size_mmu + batch_size_t2s_cfg = config.training.batch_size_t2s + batch_size_s2t_cfg = config.training.batch_size_s2t + batch_size_v2t_cfg = config.training.batch_size_v2t + batch_size_s2s_cfg = config.training.batch_size_s2s + + total_batch_size = ( + total_batch_size_per_gpu + * accelerator.num_processes + * config.training.gradient_accumulation_steps + ) + preproc_config = config.dataset.preprocessing + dataset_config = config.dataset.params + + # Text-to-image dataset + logger.info("Loading Text-to-image dataset") + dataset_t2i = None + train_dataloader_t2i = None + if batch_size_t2i_cfg > 0: + dataset_t2i = HQEditX2IDataset(split=dataset_config.get("hqedit_split", "train"), resolution=dataset_config.resolution) + + sampler_t2i = build_distributed_sampler( + dataset_t2i, + shuffle=True, + drop_last=True, + ) + + train_dataloader_t2i = DataLoader( + dataset_t2i, + batch_size=batch_size_t2i_cfg, + sampler=sampler_t2i, + shuffle=sampler_t2i is None, + num_workers=dataset_config.num_workers, + collate_fn=collate_fn_x2i, + drop_last=True, + ) + + # Language modeling dataset (HF instruction mixture) + logger.info("Loading LM dataset") + dataset_lm = None + train_dataloader_lm = None + if batch_size_lm_cfg > 0: + instruction_cfg = getattr(dataset_config, "hf_instruction_lm", {}) + if not isinstance(instruction_cfg, dict): + instruction_cfg = OmegaConf.to_container(instruction_cfg, resolve=True) + instruction_cfg = instruction_cfg or {} + + seed_lm = instruction_cfg.get("seed") + if seed_lm is None: + seed_lm = getattr(config.training, "seed", 42) or 42 + + dataset_lm = HFInstructionTextDataset( + split=instruction_cfg.get("split", "train"), + max_samples_per_source=instruction_cfg.get("max_samples_per_source"), + max_total_samples=instruction_cfg.get("max_total_samples"), + seed=int(seed_lm), + ) + + pin_memory = bool(getattr(dataset_config, "pin_memory", False)) + persistent_workers = bool(getattr(dataset_config, "persistent_workers", False)) + + sampler_lm = build_distributed_sampler( + dataset_lm, + shuffle=True, + drop_last=True, + ) + + train_dataloader_lm = DataLoader( + dataset_lm, + batch_size=batch_size_lm_cfg, + sampler=sampler_lm, + shuffle=sampler_lm is None, + collate_fn=dataset_lm.collate_fn, + num_workers=dataset_config.num_workers, + drop_last=True, + pin_memory=pin_memory, + persistent_workers=persistent_workers, + ) + + # Video Dataset + logger.info("Loading Video dataset") + video_captioning_dataset = None + train_dataloader_v2t = None + sampler_v2t = None + if batch_size_v2t_cfg > 0: + video_captioning_dataset = VideoCaptionDataset( + transform=image_transform, + tokenizer=uni_prompting.text_tokenizer, + max_seq_length=preproc_config.max_seq_length, + resolution=preproc_config.resolution, + sample_method="uniform", + num_frames=8, + ) + + sampler_v2t = build_distributed_sampler( + video_captioning_dataset, + shuffle=True, + drop_last=True, + ) + + train_dataloader_v2t = DataLoader( + video_captioning_dataset, + batch_size=batch_size_v2t_cfg, + num_workers=dataset_config.num_workers, + collate_fn=collate_fn_video_caption, + sampler=sampler_v2t, + shuffle=sampler_v2t is None, + drop_last=True, + ) + + # Speech Dataset + logger.info("Loading Speech dataset") + dataset_sm = MixedSpeechTextDataset(config.dataset.params.audio_data) + + # Speech-to-Speech Dataset (EMOVA + Instruct S2S) + dataset_s2s = None + sampler_s2s = None + train_dataloader_s2s = None + if config.training.batch_size_s2s > 0: + dataset_s2s = Speech2SpeechDataset(dataset_config.get("speech2speech", [])) + + # Multi-image interleaved dataset (MMU-style) + logger.info("Loading MMU dataset") + dataset_mmu = None + sampler_mmu = None + train_dataloader_mmu = None + if config.training.batch_size_mmu > 0: + mmu_params = dataset_config.get("mmu_interleaved", {}) + if mmu_params is None: + mmu_kwargs = {} + elif isinstance(mmu_params, dict): + mmu_kwargs = mmu_params + else: + mmu_kwargs = OmegaConf.to_container(mmu_params, resolve=True) + dataset_mmu = TextImageInterleavedDataset(**mmu_kwargs) + + logger.info("Dataset Prepared.") + + # Use distinct DistributedSamplers for each speech dataloader to avoid iterator interference + if accelerator.num_processes > 1: + sampler_s2t = DistributedSampler( + dataset_sm, + num_replicas=accelerator.num_processes, + rank=accelerator.process_index, + shuffle=True, + drop_last=True, + ) + sampler_t2s = DistributedSampler( + dataset_sm, + num_replicas=accelerator.num_processes, + rank=accelerator.process_index, + shuffle=True, + drop_last=True, + ) + if dataset_s2s is not None: + sampler_s2s = DistributedSampler( + dataset_s2s, + num_replicas=accelerator.num_processes, + rank=accelerator.process_index, + shuffle=True, + drop_last=True, + ) + if dataset_mmu is not None: + sampler_mmu = DistributedSampler( + dataset_mmu, + num_replicas=accelerator.num_processes, + rank=accelerator.process_index, + shuffle=True, + drop_last=True, + ) + else: + sampler_s2t = None + sampler_t2s = None + sampler_s2s = None + sampler_mmu = None + + train_dataloader_s2t = DataLoader( + dataset_sm, + batch_size=config.training.batch_size_s2t, + shuffle=False, + sampler=sampler_s2t, + collate_fn=collate_fn_audio, + num_workers=config.dataset.params.num_workers, + drop_last=True, + ) + train_dataloader_t2s = DataLoader( + dataset_sm, + batch_size=config.training.batch_size_t2s, + shuffle=False, + sampler=sampler_t2s, + collate_fn=collate_fn_audio, + num_workers=config.dataset.params.num_workers, + drop_last=True, + ) + + if dataset_s2s is not None: + train_dataloader_s2s = DataLoader( + dataset_s2s, + batch_size=batch_size_s2s_cfg, + shuffle=False, + sampler=sampler_s2s, + collate_fn=s2s_collate_fn, + num_workers=config.dataset.params.num_workers, + drop_last=True, + ) + + if dataset_mmu is not None: + train_dataloader_mmu = DataLoader( + dataset_mmu, + batch_size=config.training.batch_size_mmu, + shuffle=False, + sampler=sampler_mmu, + collate_fn=collate_fn_mmu_mult, + num_workers=config.dataset.params.num_workers, + drop_last=True, + ) + + # Combine these dataloaders into a single iterable model + iterables = {} + if train_dataloader_v2t is not None: + iterables["v2t_flow"] = train_dataloader_v2t + iterables["t2s_flow"] = train_dataloader_t2s + iterables["s2t_flow"] = train_dataloader_s2t + if train_dataloader_t2i is not None: + iterables["x2i_flow"] = train_dataloader_t2i + if train_dataloader_lm is not None: + iterables["lm_flow"] = train_dataloader_lm + if train_dataloader_mmu is not None: + iterables["mmu_flow"] = train_dataloader_mmu + if train_dataloader_s2s is not None: + iterables["s2s_flow"] = train_dataloader_s2s + + combined_dataloader = CombinedLoader(iterables, mode=config.dataset.combined_loader_mode) + + # image generation + total_batch_size_t2i = config.training.batch_size_t2i * accelerator.num_processes * config.training.gradient_accumulation_steps + num_update_steps_per_epoch_t2i = math.ceil(len(dataset_t2i) / total_batch_size_t2i) + + # lm + total_batch_size_lm = config.training.batch_size_lm * accelerator.num_processes * config.training.gradient_accumulation_steps + num_update_steps_per_epoch_lm = math.ceil(len(dataset_lm) / total_batch_size_lm) + + # s2t + total_batch_size_s2t = config.training.batch_size_s2t * accelerator.num_processes * config.training.gradient_accumulation_steps + num_update_steps_per_epoch_s2t = math.ceil(len(dataset_sm) / total_batch_size_s2t) + + # t2s + total_batch_size_t2s = config.training.batch_size_t2s * accelerator.num_processes * config.training.gradient_accumulation_steps + num_update_steps_per_epoch_t2s = math.ceil(len(dataset_sm) / total_batch_size_t2s) + + # s2s + total_batch_size_s2s = config.training.batch_size_s2s * accelerator.num_processes * config.training.gradient_accumulation_steps + num_update_steps_per_epoch_s2s = math.ceil(len(dataset_s2s) / total_batch_size_s2s) + + # v2t + total_batch_size_v2t = config.training.batch_size_v2t * accelerator.num_processes * config.training.gradient_accumulation_steps + num_update_steps_per_epoch_v2t = math.ceil(len(video_captioning_dataset) / total_batch_size_v2t) + + # mmu + total_batch_size_mmu = config.training.batch_size_mmu * accelerator.num_processes * config.training.gradient_accumulation_steps + num_update_steps_per_epoch_mmu = math.ceil(len(dataset_mmu) / total_batch_size_mmu) + + # Calculate num_train_epochs + num_update_steps_per_epoch = max( + num_update_steps_per_epoch_t2i, + num_update_steps_per_epoch_lm, + num_update_steps_per_epoch_s2t, + num_update_steps_per_epoch_t2s, + num_update_steps_per_epoch_v2t, + num_update_steps_per_epoch_s2s, + num_update_steps_per_epoch_mmu, + ) + + num_train_epochs = math.ceil(config.training.max_train_steps / num_update_steps_per_epoch) if num_update_steps_per_epoch > 0 else 1 + + + logger.info(f"len of T2I: {len(dataset_t2i)}") + logger.info(f"len of LM: {len(dataset_lm)}") + logger.info(f"len of Speech: {len(dataset_sm)}") + logger.info(f"len of Video: {len(video_captioning_dataset)}") + logger.info(f"len of S2S: {len(dataset_s2s)}") + logger.info(f"len of MMU: {len(dataset_mmu)}") + + logger.info(f"Train stpes: {config.training.max_train_steps}") + logger.info(f"Num train epochs: {num_train_epochs}") + + ################################## + # MODEL RESUME # + ################################# + global_step = 0 + first_epoch = 0 + start_step = 0 + + if config.experiment.resume_from_checkpoint: + dirs = os.listdir(config.experiment.output_dir) + logger.info(f"dirs: {dirs}") + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + logger.info(f"path: {path}") + if path is not None: + path = os.path.join(config.experiment.output_dir, path) + logger.info(f"Resuming from checkpoint: {path}") + global_step = start_step = int(os.path.basename(path).split("-")[1]) + first_epoch = global_step // num_update_steps_per_epoch + if os.path.exists(f'{path}/unwrapped_model/pytorch_model.bin'): + state_dict = torch.load(f'{path}/unwrapped_model/pytorch_model.bin', map_location="cpu") + model.load_state_dict(state_dict, strict=True) + del state_dict + elif os.path.exists(f'{path}/unwrapped_model/pytorch_model.bin.index.json'): + from safetensors.torch import load_file + from transformers.modeling_utils import load_sharded_checkpoint + load_sharded_checkpoint(model, f'{path}/unwrapped_model/') + # if safetensors sharded checkpoint exists + elif os.path.exists(f'{path}/unwrapped_model/model.safetensors.index.json'): + from transformers.modeling_utils import load_sharded_checkpoint + load_sharded_checkpoint( + model, + f'{path}/unwrapped_model/', + ) + else: + raise FileNotFoundError(f"Checkpoint {path}/unwrapped_model/pytorch_model.bin or safetensors not found") + else: + logger.info("Not resuming from checkpoint") + + ################################## + # Prepare accelerator # + ################################# + logger.info("Preparing model, optimizer and dataloaders") + + lr_scheduler = get_scheduler( + config.lr_scheduler.scheduler, + optimizer=optimizer, + num_training_steps=config.training.max_train_steps, + num_warmup_steps=config.lr_scheduler.params.warmup_steps, + min_lr_scale=config.lr_scheduler.params.min_lr_scale + ) + + # model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler) + model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler) + + lr_scheduler = get_scheduler( + config.lr_scheduler.scheduler, + optimizer=optimizer, + num_training_steps=config.training.max_train_steps, + num_warmup_steps=config.lr_scheduler.params.warmup_steps, + min_lr_scale=config.lr_scheduler.params.min_lr_scale + ) + + vq_model_image.to(device=accelerator.device) + vq_model_audio.to(device=accelerator.device) + + mask_dtype = model.get_input_embeddings().weight.dtype + + def _log_and_flag_failure(message: str, exc: Exception = None): + """Log preprocessing failures on both logger and accelerator console.""" + if exc is not None: + logger.exception(message) + else: + logger.error(message) + accelerator.print(message) + + def safe_audio_encode(audio_path: str, flow_name: str): + try: + tokens = vq_model_audio.encode(audio_path) + return tokens, None + except Exception as exc: + msg = ( + f"[Rank {accelerator.process_index}] {flow_name} audio encode failed " + f"for '{audio_path}': {exc}" + ) + _log_and_flag_failure(msg, exc) + return None, msg + + def safe_video_get_code(video_tensor_sample: torch.Tensor, sample_index: int): + try: + video_token = vq_model_image.get_code(video_tensor_sample) + return video_token, None + except Exception as exc: + msg = ( + f"[Rank {accelerator.process_index}] v2t video encode failed " + f"for sample index {sample_index}: {exc}" + ) + _log_and_flag_failure(msg, exc) + return None, msg + + def safe_image_get_code(image_tensor_sample: torch.Tensor, sample_index: int): + try: + if image_tensor_sample.dim() == 3: + image_tensor_sample = image_tensor_sample.unsqueeze(0) + elif image_tensor_sample.dim() != 4: + raise ValueError( + f"Expected image tensor with 3 or 4 dims, got shape {tuple(image_tensor_sample.shape)}" + ) + image_token = vq_model_image.get_code(image_tensor_sample) + return image_token, None + except Exception as exc: + msg = ( + f"[Rank {accelerator.process_index}] s2s image encode failed " + f"for sample index {sample_index}: {exc}" + ) + _log_and_flag_failure(msg, exc) + return None, msg + + def _decode_single_image(single_like): + if single_like is None: + return None + if isinstance(single_like, Image.Image): + return single_like.convert('RGB') + + data_bytes = None + + if isinstance(single_like, (bytes, bytearray)): + data_bytes = bytes(single_like) + elif isinstance(single_like, str): + try: + data_bytes = base64.b64decode(single_like) + except (binascii.Error, ValueError): + if os.path.isfile(single_like): + try: + with open(single_like, 'rb') as fh: + data_bytes = fh.read() + except OSError: + data_bytes = None + elif isinstance(single_like, dict): + binary_payload = single_like.get('bytes') + if binary_payload is not None: + data_bytes = binary_payload + else: + path_value = single_like.get('path') + if path_value and os.path.isfile(path_value): + try: + with open(path_value, 'rb') as fh: + data_bytes = fh.read() + except OSError: + data_bytes = None + + if data_bytes is None: + return None + + try: + with Image.open(BytesIO(data_bytes)) as img: + return img.convert('RGB') + except Exception: + return None + + def maybe_decode_image(image_like): + if isinstance(image_like, (list, tuple)): + return [_decode_single_image(item) for item in image_like] + return _decode_single_image(image_like) + + ################################## + # Training # + ################################# + logger.info("***** Running training *****") + logger.info(f" Num training steps = {config.training.max_train_steps}") + logger.info(f" Instantaneous batch size per device = {total_batch_size_per_gpu}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {config.training.gradient_accumulation_steps}") + + @torch.no_grad() + def prepare_inputs_and_labels( + pixel_values_or_image_ids: Union[torch.FloatTensor, torch.LongTensor], + texts: Union[str, str], + min_masking_rate: float = 0.0, + is_train: bool = True, + seed: int = None + ): + + image_tokens = vq_model_image.get_code(pixel_values_or_image_ids) + image_tokens = image_tokens + len(uni_prompting.text_tokenizer) + # create MLM mask and labels + input_ids, labels, loss_weight, mask_prob = mask_or_random_replace_tokens( + image_tokens, + mask_id, + config, + mask_schedule=mask_schedule, + is_train=is_train, + ) + input_ids, masks, labels = uni_prompting((texts, input_ids, labels), 't2i') + return input_ids, labels, mask_prob, image_tokens, masks + + @torch.no_grad() + def prepare_inputs_and_labels_for_i2i( + source_images: torch.FloatTensor, + target_images: torch.FloatTensor, + prompts: list[str], + is_train: bool = True, + ): + """Build masked i2i sequences from source/target image pairs.""" + + # Tokenize source/target images with VQ model and offset by text vocab size + source_tokens = vq_model_image.get_code(source_images) + len(uni_prompting.text_tokenizer) + target_tokens = vq_model_image.get_code(target_images) + len(uni_prompting.text_tokenizer) + + cond_dropout_prob = config.training.get( + "i2i_cond_dropout_prob", + config.training.cond_dropout_prob, + ) + + if is_train and torch.rand(1, device=source_tokens.device).item() < cond_dropout_prob: + effective_prompts = [''] * len(prompts) + masked_target_source = source_tokens + else: + effective_prompts = list(prompts) + masked_target_source = target_tokens + + masked_target_tokens, labels, _, mask_prob = mask_or_random_replace_tokens( + masked_target_source, + mask_id, + config, + mask_schedule=mask_schedule, + is_train=is_train, + ) + + input_ids, attention_masks, labels = uni_prompting( + (effective_prompts, source_tokens, masked_target_tokens, labels), + 'i2i' + ) + + return input_ids, labels, mask_prob, attention_masks + + @torch.no_grad() + def prepare_inputs_and_labels_for_text( + texts: Union[str, str], max_seq_len, eps=1e-3 + ): + # create MLM mask and labels + + input_ids_lm, prompt_mask, labels_lm = uni_prompting((texts, max_seq_len), 'lm') + b, l = input_ids_lm.shape + t = torch.rand(b, device=input_ids_lm.device) + p_mask = (1 - eps) * t + eps + p_mask = p_mask[:, None].repeat(1, l) + + masked_indices = torch.rand((b, l), device=input_ids_lm.device) < p_mask + # 126336 is used for [MASK] token + noisy_batch = torch.where(masked_indices, mask_id, input_ids_lm) + masked_indices = noisy_batch == mask_id + + return noisy_batch, labels_lm, p_mask + + # Video also uses this. + @torch.no_grad() + def prepare_inputs_and_labels_for_mmu( + input_ids_mmu, prompt_masks, labels_mmu, eps=1e-3 + ): + b, l = input_ids_mmu.shape + t = torch.rand(b, device=input_ids_mmu.device) + p_mask = (1 - eps) * t + eps + p_mask = p_mask[:, None].repeat(1, l) + + masked_indices = torch.rand((b, l), device=input_ids_mmu.device) < p_mask + # 126336 is used for [MASK] token + noisy_batch = torch.where(masked_indices, mask_id, input_ids_mmu) + masked_indices = noisy_batch == mask_id + noisy_batch[prompt_masks.bool()] = input_ids_mmu[prompt_masks.bool()] + masked_indices = noisy_batch == mask_id + + prompt_masks = prompt_masks.to(torch.int64) + answer_lengths = torch.sum((1 - prompt_masks), dim=-1, keepdim=True) + answer_lengths = answer_lengths.repeat(1, noisy_batch.shape[1]) + + return noisy_batch, labels_mmu, p_mask, answer_lengths + + @torch.no_grad() + def prepare_inputs_and_labels_for_t2s( + input_ids_t2s, prompt_masks, labels_t2s, eps=1e-3 + ): + b, l = input_ids_t2s.shape + t = torch.rand(b, device=input_ids_t2s.device) + p_mask = (1 - eps) * t + eps + p_mask = p_mask[:, None].repeat(1, l) + + masked_indices = torch.rand((b, l), device=input_ids_t2s.device) < p_mask + noisy_batch = torch.where(masked_indices, mask_id, input_ids_t2s) + masked_indices = noisy_batch == mask_id + + noisy_batch[prompt_masks.bool()] = input_ids_t2s[prompt_masks.bool()] + masked_indices = noisy_batch == mask_id + + prompt_masks = prompt_masks.to(torch.int64) + answer_lengths = torch.sum((1 - prompt_masks), dim=-1, keepdim=True) + answer_lengths = answer_lengths.repeat(1, noisy_batch.shape[1]) + + return noisy_batch, labels_t2s, p_mask, answer_lengths + + + @torch.no_grad() + def prepare_inputs_and_labels_for_s2s( + input_ids_s2s, prompt_masks, labels_s2s, eps=1e-3 + ): + b, l = input_ids_s2s.shape + t = torch.rand(b, device=input_ids_s2s.device) + p_mask = (1 - eps) * t + eps + p_mask = p_mask[:, None].repeat(1, l) + + masked_indices = torch.rand((b, l), device=input_ids_s2s.device) < p_mask + noisy_batch = torch.where(masked_indices, mask_id, input_ids_s2s) + masked_indices = noisy_batch == mask_id + + noisy_batch[prompt_masks.bool()] = input_ids_s2s[prompt_masks.bool()] + masked_indices = noisy_batch == mask_id + + prompt_masks = prompt_masks.to(torch.int64) + answer_lengths = torch.sum((1 - prompt_masks), dim=-1, keepdim=True) + answer_lengths = answer_lengths.repeat(1, noisy_batch.shape[1]) + + return noisy_batch, labels_s2s, p_mask, answer_lengths + + + @torch.no_grad() + def prepare_inputs_and_labels_for_s2t( + input_ids_mmu, prompt_masks, labels_mmu, eps=1e-3 + ): + b, l = input_ids_mmu.shape + t = torch.rand(b, device=input_ids_mmu.device) + p_mask = (1 - eps) * t + eps + p_mask = p_mask[:, None].repeat(1, l) + + masked_indices = torch.rand((b, l), device=input_ids_mmu.device) < p_mask + # 126336 is used for [MASK] token + noisy_batch = torch.where(masked_indices, mask_id, input_ids_mmu) + masked_indices = noisy_batch == mask_id + noisy_batch[prompt_masks.bool()] = input_ids_mmu[prompt_masks.bool()] + masked_indices = noisy_batch == mask_id + + prompt_masks = prompt_masks.to(torch.int64) + answer_lengths = torch.sum((1 - prompt_masks), dim=-1, keepdim=True) + answer_lengths = answer_lengths.repeat(1, noisy_batch.shape[1]) + + return noisy_batch, labels_mmu, p_mask, answer_lengths + + batch_time_m = AverageMeter() + data_time_m = AverageMeter() + end = time.time() + + for epoch in tqdm(range(first_epoch, num_train_epochs), desc="Epochs", disable=not accelerator.is_main_process, position=0): + # Ensure all samplers reshuffle in a rank-consistent way each epoch + try: + if isinstance(sampler_v2t, DistributedSampler): + sampler_v2t.set_epoch(epoch) + if accelerator.num_processes > 1: + if sampler_s2t is not None: + sampler_s2t.set_epoch(epoch) + if sampler_t2s is not None: + sampler_t2s.set_epoch(epoch) + if sampler_s2s is not None: + sampler_s2s.set_epoch(epoch) + except Exception: + pass + model.train() + for batch, batch_idx, dataloader_idx in combined_dataloader: + # Synchronize skip decision across all ranks to avoid collective mismatches + required_flows = ["v2t_flow", "t2s_flow", "s2t_flow"] + if train_dataloader_t2i is not None: + required_flows.append("x2i_flow") + if train_dataloader_lm is not None: + required_flows.append("lm_flow") + if train_dataloader_mmu is not None: + required_flows.append("mmu_flow") + if train_dataloader_s2s is not None: + required_flows.append("s2s_flow") + + local_skip = 0 + for key in required_flows: + if batch is None or batch.get(key) is None: + local_skip = 1 + break + try: + skip_tensor = torch.tensor(local_skip, device=accelerator.device, dtype=torch.int32) + skip_sum = accelerator.reduce(skip_tensor, reduction='sum') + should_skip = skip_sum.item() > 0 + except Exception: + # Fallback if reduce isn't available for any reason + should_skip = local_skip == 1 + + if should_skip: + if accelerator.is_main_process and local_skip: + logger.warning(f"Skipping step {global_step} (batch is None or v2t_flow missing) [synced]") + continue + + device = accelerator.device + + # Text-to-image samples + batch_size_t2i = 0 + mask_prob = torch.tensor(0.0, device=device) + t2i_masks = torch.empty((0, 1), dtype=torch.long, device=device) + input_ids_t2i = torch.empty((0, 1), dtype=torch.long, device=device) + labels_t2i = torch.empty((0, 1), dtype=torch.long, device=device) + batch_size_i2i = 0 + mask_prob_i2i = torch.tensor(0.0, device=device) + input_ids_i2i = torch.empty((0, 1), dtype=torch.long, device=device) + labels_i2i = torch.empty((0, 1), dtype=torch.long, device=device) + attention_masks_i2i = torch.empty((0, 1), dtype=torch.long, device=device) + + if train_dataloader_t2i is not None: + x2i_batch = batch.get("x2i_flow") + if x2i_batch is not None: + # T2I branch + t2i_texts = x2i_batch["t2i"].get("texts", []) + t2i_images_tensor = x2i_batch["t2i"].get("images") + if isinstance(t2i_images_tensor, torch.Tensor) and t2i_images_tensor.shape[0] > 0: + t2i_images_tensor = t2i_images_tensor.to(device, non_blocking=True) + batch_size_t2i = t2i_images_tensor.shape[0] + ( + input_ids_t2i, + labels_t2i, + mask_prob, + _, + t2i_masks, + ) = prepare_inputs_and_labels(t2i_images_tensor, t2i_texts, config.training.min_masking_rate) + input_ids_t2i = input_ids_t2i.to(device, non_blocking=True) + labels_t2i = labels_t2i.to(device, non_blocking=True) + t2i_masks = t2i_masks.to(device, non_blocking=True) + if mask_prob.device != device: + mask_prob = mask_prob.to(device) + # I2I branch + i2i_prompts = x2i_batch["i2i"].get("prompts", []) + i2i_source_tensor = x2i_batch["i2i"].get("source_images") + i2i_target_tensor = x2i_batch["i2i"].get("target_images") + if ( + isinstance(i2i_source_tensor, torch.Tensor) + and isinstance(i2i_target_tensor, torch.Tensor) + and i2i_source_tensor.shape[0] > 0 + and i2i_target_tensor.shape[0] > 0 + ): + i2i_source_tensor = i2i_source_tensor.to(device, non_blocking=True) + i2i_target_tensor = i2i_target_tensor.to(device, non_blocking=True) + batch_size_i2i = i2i_source_tensor.shape[0] + ( + input_ids_i2i, + labels_i2i, + mask_prob_i2i, + attention_masks_i2i, + ) = prepare_inputs_and_labels_for_i2i( + i2i_source_tensor, + i2i_target_tensor, + i2i_prompts, + is_train=True, + ) + input_ids_i2i = input_ids_i2i.to(device, non_blocking=True) + labels_i2i = labels_i2i.to(device, non_blocking=True) + attention_masks_i2i = attention_masks_i2i.to(device, non_blocking=True) + if mask_prob_i2i.device != device: + mask_prob_i2i = mask_prob_i2i.to(device) + + # Language modeling samples + batch_size_lm = 0 + input_ids_lm = torch.empty((0, 1), dtype=torch.long, device=device) + labels_lm = torch.empty((0, 1), dtype=torch.long, device=device) + p_mask_lm = torch.empty((0, 1), dtype=torch.float32, device=device) + if train_dataloader_lm is not None: + lm_batch = batch.get("lm_flow") + if lm_batch is not None: + texts_lm = lm_batch["input_ids"] + batch_size_lm = len(texts_lm) + max_seq_for_lm = input_ids_t2i.shape[1] if batch_size_t2i > 0 else preproc_config.max_seq_length + input_ids_lm, labels_lm, p_mask_lm = prepare_inputs_and_labels_for_text(texts_lm, max_seq_for_lm) + input_ids_lm = input_ids_lm.to(device, non_blocking=True) + labels_lm = labels_lm.to(device, non_blocking=True) + p_mask_lm = p_mask_lm.to(device, non_blocking=True) + + v2t_batch = batch.get("v2t_flow") + batch_size_v2t = len(v2t_batch["video"]) if v2t_batch is not None else 0 + batch_size_t2s_text = len(batch["t2s_flow"]["audio_path"]) + batch_size_s2t = len(batch["s2t_flow"]["audio_path"]) + + s2s_batch = batch.get("s2s_flow") + batch_size_s2s = 0 + if s2s_batch is not None: + batch_size_s2s = len(s2s_batch.get("emova_sft", [])) + len(s2s_batch.get("instructs2s", [])) + + mmu_batch = batch.get("mmu_flow") + batch_size_mmu = 0 + image_tensor_list = [] + texts_image = [] + if mmu_batch is not None: + image_tensor_list = mmu_batch.get("images", []) + texts_image = mmu_batch.get("text", []) + batch_size_mmu = len(image_tensor_list) + + # print(f"Rank {accelerator.process_index} loading data...") + # print(batch["s2t_flow"]["audio_path"]) + # print(batch["v2t_flow"]['captions']) + + s2t_flow = batch.get("s2t_flow", {}) + t2s_flow = batch.get("t2s_flow", {}) + audio_paths_s2t, texts_s2t = s2t_flow.get("audio_path", []), s2t_flow.get("text", []) + audio_paths_t2s, texts_t2s = t2s_flow.get("audio_path", []), t2s_flow.get("text", []) + + # Randomly drop one of the S2T/T2S tasks to ease memory pressure when both are present + if batch_size_s2t > 0 and batch_size_t2s_text > 0: + if random.random() < 0.5: + # keep S2T, drop T2S + audio_paths_t2s = [] + texts_t2s = [] + batch_size_t2s_text = 0 + else: + # keep T2S, drop S2T + audio_paths_s2t = [] + texts_s2t = [] + batch_size_s2t = 0 + else: + batch_size_s2t = len(audio_paths_s2t) + batch_size_t2s_text = len(audio_paths_t2s) + + logger.info( + f"batch_size_t2i: {batch_size_t2i}, batch_size_i2i: {batch_size_i2i}, batch_size_lm: {batch_size_lm}, " + f"batch_size_v2t: {batch_size_v2t}, batch_size_t2s: {batch_size_t2s_text}, " + f"batch_size_s2t: {batch_size_s2t}, batch_size_s2s: {batch_size_s2s}, batch_size_mmu: {batch_size_mmu}" + ) + offset = speech_vocab_start + if v2t_batch is not None: + video_tensor, texts_vid = v2t_batch["video"], v2t_batch["captions"] + else: + video_tensor = torch.empty((0, 1, 1, 1, 1), device=accelerator.device) + texts_vid = [] + + data_time_m.update(time.time() - end) + + failure_messages = [] + step_failed = False + + # *-------*-------*-------*-------*-------*-------*-------*-------*-------*-------*-------* + # Build formatted sequences for video understanding + # *-------*-------*-------*-------*-------*-------*-------*-------*-------*-------*-------* + if batch_size_v2t > 0 and not step_failed: + video_tensor = video_tensor.to(accelerator.device, non_blocking=True) + video_token_list = [] + for vid_idx, video in enumerate(video_tensor): # each video is (T, C, H, W) + tokens, err = safe_video_get_code(video, vid_idx) + if err is not None: + failure_messages.append(err) + step_failed = True + break + video_token = tokens + len(uni_prompting.text_tokenizer) # add offset for video tokens + video_token = video_token.view(-1) # flatten to (T*D) + video_token_list.append(video_token) + + if not step_failed and video_token_list: + video_tokens = torch.stack(video_token_list, dim=0) # (B, T*D) + input_ids_vid, prompt_masks_vid, labels_vid = uni_prompting((video_tokens, texts_vid), 'v2t') + # Keep trailing EOS tokens so v2t learns to emit explicit padding. + + ( + input_ids_vid, + labels_vid, + p_mask_vid, + answer_lengths_vid + ) = prepare_inputs_and_labels_for_mmu(input_ids_vid, prompt_masks_vid, labels_vid) + + input_ids_vid = input_ids_vid.to(accelerator.device, non_blocking=True) + else: + input_ids_vid = torch.empty((0, 1), dtype=torch.long, device=accelerator.device) + labels_vid = torch.empty((0, 1), dtype=torch.long, device=accelerator.device) + p_mask_vid = torch.empty((0, 1), dtype=torch.float32, device=accelerator.device) + answer_lengths_vid = torch.empty((0, 1), dtype=torch.long, device=accelerator.device) + else: + input_ids_vid = torch.empty((0, 1), dtype=torch.long, device=accelerator.device) + labels_vid = torch.empty((0, 1), dtype=torch.long, device=accelerator.device) + p_mask_vid = torch.empty((0, 1), dtype=torch.float32, device=accelerator.device) + answer_lengths_vid = torch.empty((0, 1), dtype=torch.long, device=accelerator.device) + + # *-------*-------*-------*-------*-------*-------*-------*-------*-------*-------*-------* + # Build formatted sequences for speech understanding + # *-------*-------*-------*-------*-------*-------*-------*-------*-------*-------*-------* + if not step_failed and batch_size_s2t > 0: + prompt_s2t = ['<|start_header_id|>user<|end_header_id|>\n' + prompt + '<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n' for prompt in S2T_INSTRUCTION] + + all_audio_tokens = [] + for path in audio_paths_s2t: + tokens, err = safe_audio_encode(path, "s2t") + if err is not None: + failure_messages.append(err) + step_failed = True + break + tokens = tokens.to(accelerator.device, non_blocking=True) + tokens_with_offset = tokens + offset + all_audio_tokens.append(tokens_with_offset) + + if not step_failed: + prompt = random.choice(prompt_s2t) + texts_with_prompt = [f"{prompt}{text}" for text in texts_s2t] + + input_ids_s2t, prompt_masks_s2t, labels_s2t = uni_prompting((all_audio_tokens, texts_with_prompt), 's2t') + # Preserve trailing EOS tokens in s2t targets for explicit prediction. + input_ids_s2t, labels_s2t, p_mask_s2t, answer_lengths_s2t = prepare_inputs_and_labels_for_s2t(input_ids_s2t, prompt_masks_s2t, labels_s2t) + else: + input_ids_s2t = torch.empty((0, 1), dtype=torch.long, device=accelerator.device) + labels_s2t = torch.empty((0, 1), dtype=torch.long, device=accelerator.device) + p_mask_s2t = torch.empty((0, 1), dtype=torch.float32, device=accelerator.device) + answer_lengths_s2t = torch.empty((0, 1), dtype=torch.long, device=accelerator.device) + + # *-------*-------*-------*-------*-------*-------*-------*-------*-------*-------*-------* + # Build formatted sequences for speech generation + # *-------*-------*-------*-------*-------*-------*-------*-------*-------*-------*-------* + if not step_failed and batch_size_t2s_text > 0: + prompt_t2s = [prompt for prompt in T2S_INSTRUCTION] + + all_audio_tokens = [] + for path in audio_paths_t2s: + tokens, err = safe_audio_encode(path, "t2s") + if err is not None: + failure_messages.append(err) + step_failed = True + break + tokens = tokens.to(accelerator.device, non_blocking=True) + tokens_with_offset = tokens + offset + all_audio_tokens.append(tokens_with_offset) + + if not step_failed: + # Chat-style instruction formatting for T2S training + prompt = random.choice(prompt_t2s) + texts_with_prompt = [ + f"<|start_header_id|>user<|end_header_id|>\n{prompt}\n{text}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n" + for text in texts_t2s + ] + + # input_ids_t2s, prompt_masks_t2s, labels_t2s = uni_prompting((texts_with_prompt, all_audio_tokens), 't2s_ip') + input_ids_t2s, prompt_masks_t2s, labels_t2s = uni_prompting((texts_with_prompt, all_audio_tokens), 't2s') + input_ids_t2s, labels_t2s, p_mask_t2s, answer_lengths_t2s = prepare_inputs_and_labels_for_t2s(input_ids_t2s, prompt_masks_t2s, labels_t2s) + else: + input_ids_t2s = torch.empty((0, 1), dtype=torch.long, device=accelerator.device) + labels_t2s = torch.empty((0, 1), dtype=torch.long, device=accelerator.device) + p_mask_t2s = torch.empty((0, 1), dtype=torch.float32, device=accelerator.device) + answer_lengths_t2s = torch.empty((0, 1), dtype=torch.long, device=accelerator.device) + + audio_user_ids_s2s: list[torch.Tensor] = [] + audio_asst_ids_s2s: list[torch.Tensor] = [] + image_token_blocks_s2s: list[Optional[torch.Tensor]] = [] + input_ids_s2s = None + labels_s2s = None + p_mask_s2s = None + answer_lengths_s2s = None + + if not step_failed and batch_size_s2s > 0 and s2s_batch is not None: + s2s_sample_counter = 0 + + emova_samples = s2s_batch.get("emova_sft", []) + for sample_idx, (usr_ids, asst_ids, image_like) in enumerate(emova_samples): + usr_tensor = torch.tensor(usr_ids, dtype=torch.long, device=accelerator.device).unsqueeze(0) + asst_tensor = torch.tensor(asst_ids, dtype=torch.long, device=accelerator.device).unsqueeze(0) + + audio_user_ids_s2s.append(usr_tensor + offset) + audio_asst_ids_s2s.append(asst_tensor + offset) + + decoded_payload = maybe_decode_image(image_like) + if isinstance(decoded_payload, (list, tuple)): + token_payloads = [] + for local_img in decoded_payload: + if local_img is None: + continue + pixel_values = image_transform(local_img, resolution=preproc_config.resolution).to(accelerator.device) + if pixel_values.dim() == 3: + pixel_values = pixel_values.unsqueeze(0) + image_tokens_raw, err = safe_image_get_code(pixel_values, s2s_sample_counter) + if err is not None: + failure_messages.append(err) + step_failed = True + break + image_tokens = image_tokens_raw + len(uni_prompting.text_tokenizer) + token_payloads.append(image_tokens.view(1, -1).to(accelerator.device, non_blocking=True)) + s2s_sample_counter += 1 + if step_failed: + break + image_token_blocks_s2s.append(token_payloads if token_payloads else None) + else: + if decoded_payload is not None: + pixel_values = image_transform(decoded_payload, resolution=preproc_config.resolution).to(accelerator.device) + if pixel_values.dim() == 3: + pixel_values = pixel_values.unsqueeze(0) + image_tokens_raw, err = safe_image_get_code(pixel_values, s2s_sample_counter) + if err is not None: + failure_messages.append(err) + step_failed = True + break + image_tokens = image_tokens_raw + len(uni_prompting.text_tokenizer) + image_token_blocks_s2s.append(image_tokens.view(1, -1).to(accelerator.device, non_blocking=True)) + s2s_sample_counter += 1 + else: + image_token_blocks_s2s.append(None) + + instruct_samples = [] if step_failed else s2s_batch.get("instructs2s", []) + if not step_failed: + for sample_idx, (usr_wav, asst_wav, _) in enumerate(instruct_samples): + user_tokens, err_usr = safe_audio_encode(usr_wav, "s2s-user") + if err_usr is not None: + failure_messages.append(err_usr) + step_failed = True + break + asst_tokens, err_asst = safe_audio_encode(asst_wav, "s2s-assistant") + if err_asst is not None: + failure_messages.append(err_asst) + step_failed = True + break + + user_tokens = user_tokens.to(accelerator.device, non_blocking=True) + asst_tokens = asst_tokens.to(accelerator.device, non_blocking=True) + if user_tokens.dim() == 1: + user_tokens = user_tokens.unsqueeze(0) + if asst_tokens.dim() == 1: + asst_tokens = asst_tokens.unsqueeze(0) + + audio_user_ids_s2s.append(user_tokens + offset) + audio_asst_ids_s2s.append(asst_tokens + offset) + image_token_blocks_s2s.append(None) + s2s_sample_counter += 1 + + if not step_failed and audio_user_ids_s2s: + input_ids_s2s, prompt_masks_s2s, labels_s2s = uni_prompting( + (audio_user_ids_s2s, audio_asst_ids_s2s, image_token_blocks_s2s), + 's2s_eos' + ) + + print(input_ids_s2s) + + input_ids_s2s, labels_s2s, p_mask_s2s, answer_lengths_s2s = prepare_inputs_and_labels_for_s2s( + input_ids_s2s, + prompt_masks_s2s, + labels_s2s, + ) + + if input_ids_s2s is None: + device = accelerator.device + input_ids_s2s = torch.empty((0, 1), dtype=torch.long, device=device) + labels_s2s = torch.empty((0, 1), dtype=torch.long, device=device) + p_mask_s2s = torch.empty((0, 1), dtype=torch.float32, device=device) + answer_lengths_s2s = torch.empty((0, 1), dtype=torch.long, device=device) + + input_ids_mmu = None + labels_mmu = None + p_mask_mmu = None + answer_lengths_mmu = None + + if not step_failed and batch_size_mmu > 0: + batch_image_ids_list = [] + batch_text_ids = [] + + for b_idx, image_list in enumerate(image_tensor_list): + per_img_ids = [] + for j, img in enumerate(image_list): + tok, err = safe_image_get_code( + img.to(accelerator.device, non_blocking=True), + sample_index=j + ) + if err is not None: + failure_messages.append(err) + step_failed = True + break + + tok = tok.to(accelerator.device, non_blocking=True).view(-1).long() + tok = tok + len(uni_prompting.text_tokenizer) + per_img_ids.append(tok) + + if step_failed: + break + + batch_image_ids_list.append(per_img_ids) + text_ids = uni_prompting.text_tokenizer.encode(texts_image[b_idx], add_special_tokens=False) + batch_text_ids.append(text_ids) + + if not step_failed: + input_ids_mmu, prompt_masks_mmu, labels_mmu = uni_prompting.mmu_mult_prompt( + batch_image_ids_list=batch_image_ids_list, + batch_text_ids=batch_text_ids, + ) + + ( + input_ids_mmu, + labels_mmu, + p_mask_mmu, + answer_lengths_mmu + ) = prepare_inputs_and_labels_for_mmu(input_ids_mmu, prompt_masks_mmu, labels_mmu) + + input_ids_mmu = input_ids_mmu.to(accelerator.device, non_blocking=True) + labels_mmu = labels_mmu.to(accelerator.device, non_blocking=True) + p_mask_mmu = p_mask_mmu.to(accelerator.device, non_blocking=True) + answer_lengths_mmu = answer_lengths_mmu.to(accelerator.device, non_blocking=True) + + if batch_size_mmu == 0 or input_ids_mmu is None: + input_ids_mmu = torch.empty((0, 1), dtype=torch.long, device=accelerator.device) + labels_mmu = torch.empty((0, 1), dtype=torch.long, device=accelerator.device) + p_mask_mmu = torch.empty((0, 1), dtype=torch.float32, device=accelerator.device) + answer_lengths_mmu = torch.empty((0, 1), dtype=torch.long, device=accelerator.device) + if not step_failed: + total_batch_size_t2s = batch_size_t2s_text + else: + total_batch_size_t2s = batch_size_t2s_text + + failure_tensor = torch.tensor(1 if step_failed else 0, device=accelerator.device, dtype=torch.int32) + failure_sum = accelerator.reduce(failure_tensor, reduction='sum') + if failure_sum.item() > 0: + if accelerator.is_main_process and failure_messages: + for msg in failure_messages: + logger.warning(f"Skipping global step {global_step} due to preprocessing failure: {msg}") + batch_time_m.reset() + data_time_m.reset() + end = time.time() + continue + + # -------------------------------------------------------------------------------- + # for name, tensor in [ + # ("t2i", input_ids_t2i), + # ("i2i", input_ids_i2i), + # ("lm", input_ids_lm), + # ("mmu", input_ids_mmu), + # ("vid", input_ids_vid), + # ("s2t", input_ids_s2t), + # ("s2s", input_ids_s2s), + # ("t2s", input_ids_t2s), + # ]: + # if tensor is not None: + # print(f"{name:>4}: shape={getattr(tensor, 'shape', None)}, len={len(tensor) if hasattr(tensor, '__len__') else 'N/A'}") + + # 1. Define padding values + pad_token_id = uni_prompting.text_tokenizer.eos_token_id + + # 2. Find the maximum sequence length in the current batch + seq_lengths = [] + if input_ids_t2i.shape[0] > 0: + seq_lengths.append(input_ids_t2i.shape[1]) + if input_ids_i2i.shape[0] > 0: + seq_lengths.append(input_ids_i2i.shape[1]) + if input_ids_lm.shape[0] > 0: + seq_lengths.append(input_ids_lm.shape[1]) + seq_lengths.extend([ + input_ids_vid.shape[1], + input_ids_s2t.shape[1], + input_ids_t2s.shape[1], + ]) + if input_ids_s2s.shape[0] > 0: + seq_lengths.append(input_ids_s2s.shape[1]) + if input_ids_mmu.shape[0] > 0: + seq_lengths.append(input_ids_mmu.shape[1]) + max_len = max(seq_lengths) + + # 3. Pad all tensors to the max_len + input_ids_t2i = pad_tensor(input_ids_t2i, max_len, pad_token_id) + labels_t2i = pad_tensor(labels_t2i, max_len, -100) + if t2i_masks.shape[0] > 0: + t2i_masks = pad_tensor(t2i_masks.long(), max_len, 0) + else: + t2i_masks = torch.empty((0, max_len), dtype=torch.long, device=device) + + input_ids_i2i = pad_tensor(input_ids_i2i, max_len, pad_token_id) + labels_i2i = pad_tensor(labels_i2i, max_len, -100) + if attention_masks_i2i.shape[0] > 0: + attention_masks_i2i = pad_tensor(attention_masks_i2i.long(), max_len, 0) + else: + attention_masks_i2i = torch.empty((0, max_len), dtype=torch.long, device=device) + + + input_ids_lm = pad_tensor(input_ids_lm, max_len, pad_token_id) + labels_lm = pad_tensor(labels_lm, max_len, -100) + p_mask_lm = pad_tensor(p_mask_lm, max_len, 1.0) + + input_ids_vid = pad_tensor(input_ids_vid, max_len, pad_token_id) + input_ids_s2t = pad_tensor(input_ids_s2t, max_len, pad_token_id) + input_ids_t2s = pad_tensor(input_ids_t2s, max_len, pad_token_id) + input_ids_s2s = pad_tensor(input_ids_s2s, max_len, pad_token_id) + input_ids_mmu = pad_tensor(input_ids_mmu, max_len, pad_token_id) + labels_vid = pad_tensor(labels_vid, max_len, -100) + labels_s2t = pad_tensor(labels_s2t, max_len, -100) + labels_t2s = pad_tensor(labels_t2s, max_len, -100) + labels_s2s = pad_tensor(labels_s2s, max_len, -100) + labels_mmu = pad_tensor(labels_mmu, max_len, -100) + p_mask_vid = pad_tensor(p_mask_vid, max_len, 1.0) + p_mask_s2t = pad_tensor(p_mask_s2t, max_len, 1.0) + p_mask_t2s = pad_tensor(p_mask_t2s, max_len, 1.0) + p_mask_s2s = pad_tensor(p_mask_s2s, max_len, 1.0) + p_mask_mmu = pad_tensor(p_mask_mmu, max_len, 1.0) + answer_lengths_vid = pad_answer_lengths(answer_lengths_vid, max_len) + answer_lengths_s2t = pad_answer_lengths(answer_lengths_s2t, max_len) + answer_lengths_t2s = pad_answer_lengths(answer_lengths_t2s, max_len) + answer_lengths_s2s = pad_answer_lengths(answer_lengths_s2s, max_len) + answer_lengths_mmu = pad_answer_lengths(answer_lengths_mmu, max_len) + + input_ids = torch.cat(( + input_ids_t2i, + input_ids_i2i, + input_ids_lm, + input_ids_mmu, + input_ids_vid, + input_ids_s2t, + input_ids_s2s, + input_ids_t2s + ), dim=0) + labels = torch.cat(( + labels_t2i, + labels_i2i, + labels_lm, + labels_mmu, + labels_vid, + labels_s2t, + labels_s2s, + labels_t2s + ), dim=0) + + # w/o texts and images + if batch_size_lm == 0: + p_mask_lm = torch.empty((0, max_len), dtype=torch.float32, device=device) + if batch_size_t2i == 0 and t2i_masks.shape[0] == 0: + t2i_masks = torch.empty((0, max_len), dtype=torch.long, device=device) + + if global_step == 0 and epoch == 0: + logger.info("Input ids: {}".format(input_ids)) + logger.info("Input ids shape: {}".format(input_ids.shape)) + logger.info("Labels: {}".format(labels)) + + # with accelerator.accumulate(model): + logits, loss_t2i, loss_i2i, loss_lm, loss_mmu, loss_vid, loss_s2t, loss_s2s, loss_t2s = accelerator.unwrap_model(model).forward_process( + # logits, loss_t2i, loss_lm, loss_mmu, loss_vid, loss_s2t, loss_t2s = model.forward_process( + input_ids=input_ids, + labels=labels, + batch_size_t2i=batch_size_t2i, + batch_size_i2i=batch_size_i2i, + batch_size_lm=batch_size_lm, + batch_size_mmu=batch_size_mmu, + batch_size_v2t=batch_size_v2t, + batch_size_s2t=batch_size_s2t, + batch_size_s2s=batch_size_s2s, + batch_size_t2s=total_batch_size_t2s, + max_seq_length=config.dataset.preprocessing.max_seq_length, + attention_masks_i2i=attention_masks_i2i, + p_mask_lm=p_mask_lm, + p_mask_mmu=p_mask_mmu, + p_mask_vid=p_mask_vid, + p_mask_s2t=p_mask_s2t, + p_mask_s2s=p_mask_s2s, + p_mask_t2s=p_mask_t2s, + answer_lengths_mmu=answer_lengths_mmu, + answer_lengths_vid=answer_lengths_vid, + answer_lengths_s2t=answer_lengths_s2t, + answer_lengths_s2s=answer_lengths_s2s, + answer_lengths_t2s=answer_lengths_t2s, + t2i_masks=t2i_masks, + t2s_vocab_start=speech_vocab_start, + t2s_codebook_size=audio_codebook_size, + t2s_special_token_ids=t2s_special_token_ids, + ) + + # Gather the losses across all processes for logging (use reduce to avoid shape mismatches) + avg_loss_t2i = accelerator.reduce(loss_t2i, reduction='mean') + avg_loss_i2i = accelerator.reduce(loss_i2i, reduction='mean') + avg_loss_lm = accelerator.reduce(loss_lm, reduction='mean') + avg_loss_mmu = accelerator.reduce(loss_mmu, reduction='mean') + avg_loss_vid = accelerator.reduce(loss_vid, reduction='mean') + avg_loss_s2t = accelerator.reduce(loss_s2t, reduction='mean') + avg_loss_s2s = accelerator.reduce(loss_s2s, reduction='mean') + avg_loss_t2s = accelerator.reduce(loss_t2s, reduction='mean') + + mmu_coeff = getattr(config.training, "mmu_coeff", 0.0) + i2i_coeff = getattr(config.training, "i2i_coeff", config.training.t2i_coeff) + s2s_coeff = getattr(config.training, "s2s_coeff", config.training.t2s_coeff) + loss = ( + config.training.t2i_coeff * loss_t2i + + i2i_coeff * loss_i2i + + config.training.lm_coeff * loss_lm + + mmu_coeff * loss_mmu + + config.training.v2t_coeff * loss_vid + + config.training.s2t_coeff * loss_s2t + + s2s_coeff * loss_s2s + + config.training.t2s_coeff * loss_t2s + ) + + if batch_size_t2i > 0: + avg_masking_rate = accelerator.reduce(mask_prob.float().mean(), reduction='mean') + else: + avg_masking_rate = torch.tensor(0.0, device=accelerator.device) + + if batch_size_i2i > 0: + avg_masking_rate_i2i = accelerator.reduce(mask_prob_i2i.float().mean(), reduction='mean') + else: + avg_masking_rate_i2i = torch.tensor(0.0, device=accelerator.device) + + if batch_size_s2s > 0: + avg_masking_rate_s2s = accelerator.reduce(p_mask_s2s.float().mean(), reduction='mean') + else: + avg_masking_rate_s2s = torch.tensor(0.0, device=accelerator.device) + + accelerator.backward(loss) + + if config.training.max_grad_norm is not None and accelerator.sync_gradients: + accelerator.clip_grad_norm_(model.parameters(), config.training.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + + # log gradient norm before zeroing it + if ( + accelerator.sync_gradients + and (global_step + 1) % config.experiment.log_grad_norm_every == 0 + and accelerator.is_main_process + ): + log_grad_norm(model, accelerator, global_step + 1) + + optimizer.zero_grad(set_to_none=True) + + if accelerator.sync_gradients: + batch_time_m.update(time.time() - end) + end = time.time() + + # Log metrics + if (global_step + 1) % config.experiment.log_every == 0: + samples_per_second_per_gpu = ( + config.training.gradient_accumulation_steps * total_batch_size_per_gpu / batch_time_m.val + ) + logs = { + "lr": lr_scheduler.get_last_lr()[0], + "avg_masking_rate": avg_masking_rate.item(), + "avg_masking_rate_i2i": avg_masking_rate_i2i.item(), + "avg_masking_rate_s2s": avg_masking_rate_s2s.item(), + "samples/sec/gpu": samples_per_second_per_gpu, + "data_time": data_time_m.val, + "batch_time": batch_time_m.val, + } + + loss_entries = [ + ("step_loss_t2i", avg_loss_t2i), + ("step_loss_i2i", avg_loss_i2i), + ("step_loss_lm", avg_loss_lm), + ("step_loss_mmu", avg_loss_mmu), + ("step_loss_vid", avg_loss_vid), + ("step_loss_s2t", avg_loss_s2t), + ("step_loss_s2s", avg_loss_s2s), + ("step_loss_t2s", avg_loss_t2s), + ] + + loss_log_parts = [] + for key, value in loss_entries: + loss_value = value.item() + if loss_value != 0.0: + logs[key] = loss_value + loss_log_parts.append(f"{key.replace('step_', '').capitalize()}: {loss_value:0.4f}") + + accelerator.log(logs, step=global_step + 1) + + loss_str = " ".join(loss_log_parts) + logger.info( + "Step: %d %s Data (t): %.4f, %.2f/s/gpu Batch (t): %.4f LR: %.6f" + % ( + global_step + 1, + loss_str, + data_time_m.val, + samples_per_second_per_gpu, + batch_time_m.val, + lr_scheduler.get_last_lr()[0], + ) + ) + + # resetting batch / data time meters per log window + batch_time_m.reset() + data_time_m.reset() + + # Save model checkpoint + if (global_step + 1) % config.experiment.save_every == 0: + save_checkpoint(model, config, accelerator, global_step + 1, uni_prompting) + + # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + # ++++++++++++++++++++++ RUN EVALUATION +++++++++++++++++++++++++ + # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + if global_step == 0 or (global_step + 1) % config.experiment.get("eval_every", 5000) == 0: + run_evaluation( + model=accelerator.unwrap_model(model), + vq_model_image=vq_model_image, + vq_model_audio=vq_model_audio, + uni_prompting=uni_prompting, + config=config, + accelerator=accelerator, + global_step=global_step + 1 + ) + # Evaluation function sets model back to train mode internally + # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + + global_step += 1 + + if global_step >= config.training.max_train_steps: + break + + if global_step >= config.training.max_train_steps: + break + + accelerator.wait_for_everyone() + + save_checkpoint(model, config, accelerator, global_step, uni_prompting) + + if accelerator.is_main_process: + model = accelerator.unwrap_model(model) + model.save_pretrained(config.experiment.output_dir, safe_serialization=True) + + accelerator.end_training() + +@torch.no_grad() +def visualize_predictions(*args, **kwargs): + # This function is not called in the main loop but kept for compatibility + pass + +@torch.no_grad() +def generate_images(*args, **kwargs): + # This function is not called in the main loop but kept for compatibility + pass + +@torch.no_grad() +def understanding_images(*args, **kwargs): + # This function is not called in the main loop but kept for compatibility + pass + +def save_checkpoint(model, config, accelerator, global_step, uni_prompting): + output_dir = config.experiment.output_dir + checkpoints_total_limit = config.experiment.get("checkpoints_total_limit", None) + + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if accelerator.is_main_process and checkpoints_total_limit is not None: + checkpoints = os.listdir(output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= checkpoints_total_limit: + num_to_remove = len(checkpoints) - checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = Path(output_dir) / f"checkpoint-{global_step}" + + # retrieve the model on all processes for deepspeed stage 3 to work then save on one process (we are not using stage 3 yet) + # XXX: could also make this conditional on deepspeed + state_dict = accelerator.get_state_dict(model) + if accelerator.is_main_process: + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.save_pretrained( + save_path / "unwrapped_model", + save_function=accelerator.save, + state_dict=state_dict, + safe_serialization=True + ) + json.dump({"global_step": global_step}, (save_path / "metadata.json").open("w+")) + logger.info(f"Saved state to {save_path}") + + # save tokenizer + uni_prompting.text_tokenizer.save_pretrained(save_path/ "unwrapped_model") + + +def log_grad_norm(model, accelerator, global_step): + for name, param in model.named_parameters(): + if param.grad is not None: + grads = param.grad.detach().data + grad_norm = (grads.norm(p=2) / grads.numel()).item() + accelerator.log({"grad_norm/" + name: grad_norm}, step=global_step) + + +if __name__ == "__main__": + main() diff --git a/MMaDA/training/train_omada_inst_test_multi_image.py b/MMaDA/training/train_omada_inst_test_multi_image.py new file mode 100644 index 0000000000000000000000000000000000000000..ca127cb3a22ef20ef7cba3577dba896ad6d4acee --- /dev/null +++ b/MMaDA/training/train_omada_inst_test_multi_image.py @@ -0,0 +1,1881 @@ +# Copyright 2025 AIDAS Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +os.environ["TOKENIZERS_PARALLELISM"] = "true" +import json +import pandas +import logging +import math +import shutil +import time +import cv2 +import glob +import random +from tqdm import tqdm +from pathlib import Path +from typing import Optional, Union +import csv +import numpy as np +from PIL import Image +from omegaconf import OmegaConf +import wandb +import torch +from torch.optim import AdamW +from lightning.pytorch.utilities import CombinedLoader + +from transformers import AutoTokenizer, AutoConfig +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import DistributedType, set_seed +# +++++ I2I-specific Imports +++++ +from datasets import load_dataset +from torch.utils.data import Dataset, DataLoader +from tqdm.auto import tqdm +# ++++++++++++++++++++++++++++++ + +# +++++ Omni-modal-specific Imports +++++ +from models.modeling_emova_speech_tokenizer import EMOVASpeechTokenizer +from datasets import load_dataset +from torch.utils.data import Dataset, DataLoader, DistributedSampler +from tqdm.auto import tqdm +from training.data import SpeechTextDataset, MixedSpeechTextDataset, load_video_mp4, VideoCaptionDataset, S2T_INSTRUCTION, T2S_INSTRUCTION, TextImageInterleavedDataset +# import librosa + +from training.data import Text2ImageDataset +from training.utils import get_config, flatten_omega_conf, image_transform +from training.imagenet_dataset import ImageNetDataset +from parquet import RefinedWebDataset, ChatDataset + +from models import MAGVITv2, get_mask_schedule, OMadaModelLM, OMadaConfig +from training.prompting_utils import UniversalPrompting +from models.lr_schedulers import get_scheduler +from models.logging import set_verbosity_info, set_verbosity_error + +from torch.utils.data import DataLoader, Dataset +from torch.utils.data.distributed import DistributedSampler + +# ++++++++ EVALUATION IMPORTS ++++++++ +import re +import editdistance +import soundfile as sf +from functools import partial +from transformers import pipeline +# ++++++++++++++++++++++++++++++++++++ + +SYSTEM_PROMPT_LEN = 28 + +from training.utils import get_config, flatten_omega_conf, mask_or_random_replace_tokens, AverageMeter + +try: + import apex + + is_apex_available = True +except ImportError: + is_apex_available = False + +logger = get_logger(__name__, log_level="INFO") + +def pad_tensor(tensor, length, value): + pad_size = length - tensor.shape[1] + if pad_size <= 0: + return tensor + # Pad on the right side of the sequence (last dimension) + return torch.nn.functional.pad(tensor, (0, pad_size), "constant", value) + +def pad_answer_lengths(ans: torch.Tensor, length: int) -> torch.Tensor: + b, l = ans.shape + if l >= length: + return ans + pad_block = ans[:, :1].expand(b, length - l) + return torch.cat([ans, pad_block], dim=1) + +def resize_vocab(model, config): + logger.info(f"Resizing token embeddings to {config.model.omada.new_vocab_size}") + model.resize_token_embeddings(config.model.omada.new_vocab_size) + +def get_vq_model_class(model_type): + if model_type == "magvitv2": + return MAGVITv2 + elif model_type == "emova": + return EMOVASpeechTokenizer.from_pretrained( + "Emova-ollm/emova_speech_tokenizer_hf" + ) + else: + raise ValueError(f"model_type {model_type} not supported.") + +def collate_fn_audio(batch): + # In this setup, the tokenizer handles batching of audio paths + return { + 'audio_path': [item['audio_path'] for item in batch], + 'text': [item['text'] for item in batch], + } + +# ##### +# 추가 +# ##### + +def collate_fn_mmu_mult(batch): + # In this setup, the tokenizer handles batching of audio paths + return { + 'images': [item['images'] for item in batch], + 'text': [item['text'] for item in batch], + } + +def collate_fn_video_caption(batch): + + batch = [item for item in batch if item is not None] + if len(batch) == 0: + return None + + frame_list = [] + input_ids_list = [] + for item in batch: + frame_tensor = torch.stack(item['video'], dim=0) # (T, C, H, W) + frame_list.append(frame_tensor) + input_ids_list.append(item['caption']) + + frames = torch.stack(frame_list, dim=0) # (B, T, C, H, W) + + return { + "video": frames, # torch tensor (B, T, C, H, W) + "captions": input_ids_list # input_ids (B, seq_len) + } + +def s2t_eval_collate_fn(batch, vq_model_audio, tokenizer, uni_prompting, config): + + audio_tokens_batch = [] + offset = len(uni_prompting.text_tokenizer) + int(config.model.omada.codebook_size) + for item in batch: + path = item['audio_path'] + tokens = vq_model_audio.encode(path) + tokens_with_offset = tokens + offset + audio_tokens_batch.append(tokens_with_offset) + + sptids_dict = uni_prompting.sptids_dict + device = audio_tokens_batch[0].device + batched_input_ids = [] + + for audio_tokens in audio_tokens_batch: + task_tensor = sptids_dict['<|s2t|>'].to(device).unsqueeze(0) + soa_tensor = sptids_dict['<|soa|>'].to(device).unsqueeze(0) + eoa_tensor = sptids_dict['<|eoa|>'].to(device).unsqueeze(0) + audio_block = torch.cat([task_tensor, soa_tensor, audio_tokens, eoa_tensor], dim=1) + + prompt_text = random.choice(S2T_INSTRUCTION) + full_prompt_text = f'<|start_header_id|>user<|end_header_id|>\n{prompt_text}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n' + prompt_tensor = tokenizer(full_prompt_text, return_tensors="pt").input_ids.to(device) + + final_sequence = torch.cat([audio_block, prompt_tensor], dim=1) + batched_input_ids.append(final_sequence.squeeze(0)) + + max_len = max(seq.size(0) for seq in batched_input_ids) + pad_token_id = 126093 + + final_batch_input_ids = torch.full( + (len(batched_input_ids), max_len), + pad_token_id, + dtype=torch.long, + device=device + ) + + for i, seq in enumerate(batched_input_ids): + final_batch_input_ids[i, -len(seq):] = seq + + return { + "input_ids": final_batch_input_ids, + "gt_texts": [item['gt_text'] for item in batch], + "sample_ids": [item['sample_id'] for item in batch] + } + +################################################################################################ +# +++++++++++++++++++++++++++++++++++++ EVALUATION HELPERS +++++++++++++++++++++++++++++++++++++ +################################################################################################ + +def normalize_text(text): + """A simple normalizer for WER calculation.""" + text = text.lower() + text = re.sub(r"[^\w\s']", "", text) + return text + +def calculate_wer(predictions, references): + """Calculates the Word Error Rate (WER) between predicted and ground truth texts.""" + predictions = [normalize_text(p) for p in predictions] + references = [normalize_text(r) for r in references] + + total_errors = 0 + total_words = 0 + for pred, ref in zip(predictions, references): + pred_words = pred.split() + ref_words = ref.split() + total_errors += editdistance.eval(pred_words, ref_words) + total_words += len(ref_words) + + wer = total_errors / total_words if total_words > 0 else 0.0 + return wer, total_errors, total_words + +class S2TEvalDataset(Dataset): + def __init__(self, hf_dataset, root_path): + self.hf_dataset = hf_dataset + self.root_path = root_path + + def __len__(self): + return len(self.hf_dataset) + + def __getitem__(self, idx): + example = self.hf_dataset[idx] + sample_id = example['id'] + speaker_id, chapter_id, _ = sample_id.split('-') + audio_path = os.path.join(self.root_path, speaker_id, chapter_id, f"{sample_id}.flac") + + return { + "audio_path": audio_path, + "gt_text": example["text"], + "sample_id": sample_id + } + +# --- T2S Evaluation Dataset --- +class T2SEvalDataset(Dataset): + def __init__(self, hf_dataset): + self.hf_dataset = hf_dataset + def __len__(self): + return len(self.hf_dataset) + def __getitem__(self, idx): + example = self.hf_dataset[idx] + return {"gt_text": example['text'], "sample_id": example['id']} + + +################################################################################################ +# +++++++++++++++++++++++++++++++++++++ S2T EVALUATION LOGIC +++++++++++++++++++++++++++++++++++++ +################################################################################################ +@torch.no_grad() +def evaluate_s2t(model, vq_model_audio, uni_prompting, config, accelerator, global_step): + if not accelerator.is_main_process: + return + logger.info("***** Running S2T Evaluation (WER on Librispeech test-clean) *****") + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.eval() + + # 1. Load Dataset + try: + s2t_eval_dataset_raw = load_dataset("librispeech_asr", "clean", split="test", streaming=False).select(range(32)) + s2t_eval_dataset = S2TEvalDataset(s2t_eval_dataset_raw, root_path = "/home/work/AIDAS/data/audio/LibriSpeech/test-clean") + except Exception as e: + logger.error(f"Failed to load S2T evaluation dataset: {e}") + return + + collate_with_args = partial( + s2t_eval_collate_fn, + vq_model_audio=vq_model_audio, + tokenizer=uni_prompting.text_tokenizer, + uni_prompting=uni_prompting, + config=config + ) + + s2t_eval_dataloader = DataLoader(s2t_eval_dataset, batch_size=config.training.batch_size_s2t, shuffle=False, collate_fn=collate_with_args) + + local_results = [] + + for batch in tqdm(s2t_eval_dataloader, desc="S2T Evaluation"): + input_ids = batch["input_ids"] + gt_texts = batch["gt_texts"] + sample_ids = batch["sample_ids"] + + output_ids = unwrapped_model.mmu_generate(input_ids, max_new_tokens=256, steps=256, block_length=128, remasking='low_confidence') + + decoded_texts = uni_prompting.text_tokenizer.batch_decode(output_ids[:, input_ids.shape[1]:], skip_special_tokens=True) + + eos_token = uni_prompting.text_tokenizer.eos_token + eos_marker = eos_token if eos_token is not None else "" + for i in range(len(decoded_texts)): + full_text = decoded_texts[i] + eos_idx = full_text.find(eos_marker) + cleaned_text = full_text[:eos_idx] if eos_idx != -1 else full_text + cleaned_text = cleaned_text.replace(eos_marker, "").strip() + local_results.append({ + "sample_id": sample_ids[i], + "gt_text": gt_texts[i], + "decoded_text": cleaned_text, + }) + + if not local_results: + logger.warning("S2T evaluation produced no results.") + return + + gt_list = [res["gt_text"] for res in local_results] + pred_list = [res["decoded_text"] for res in local_results] + + wer, errors, words = calculate_wer(pred_list, gt_list) + logger.info(f"S2T Final WER (Librispeech test-clean): {wer:.4f} | Word Errors: {errors} | Total Words: {words}") + + accelerator.log({ + "eval/s2t_wer": wer, + "eval/s2t_word_errors": errors, + "eval/s2t_total_words": words + }, step=global_step) + + samples_table = wandb.Table(columns=["ID", "Ground Truth", "Prediction"]) + for idx, res in enumerate(local_results): + sample_id = res.get("sample_id", idx) + samples_table.add_data(sample_id, res["gt_text"], res["decoded_text"]) + + accelerator.log({"eval/s2t_samples": samples_table}, step=global_step) + +################################################################################################ +# +++++++++++++++++++++++++++++++++++++ T2S EVALUATION LOGIC +++++++++++++++++++++++++++++++++++++ +################################################################################################ +@torch.no_grad() +def evaluate_t2s(model, vq_model_audio, uni_prompting, config, accelerator, global_step): + if not accelerator.is_main_process: + return + logger.info("***** Running T2S Evaluation (WER via Whisper on Librispeech) *****") + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.eval() + + # 1. Load Dataset & Whisper Model + try: + t2s_eval_dataset_raw = load_dataset("librispeech_asr", "clean", split="test").select(range(32)) + whisper_pipe = pipeline("automatic-speech-recognition", model="openai/whisper-large-v3", device=accelerator.device) + os.makedirs(f"{config.experiment.output_dir}/eval_audio", exist_ok=True) + except Exception as e: + logger.error(f"Failed to load T2S dataset or Whisper model: {e}") + return + + output_dir_per_step = os.path.join("/home/work/AIDAS", config.experiment.output_dir, "eval_audio", f"step_{global_step}") + os.makedirs(output_dir_per_step, exist_ok=True) + + t2s_eval_dataset = T2SEvalDataset(t2s_eval_dataset_raw) + t2s_dataloader = DataLoader(t2s_eval_dataset, batch_size=config.training.batch_size_t2s) + + local_results = [] + mask_token_id = unwrapped_model.config.mask_token_id + mask_schedule = get_mask_schedule(config.training.get("mask_schedule", "cosine")) + + # 2. Evaluation Loop + for batch in tqdm(t2s_dataloader, desc="T2S Evaluation"): + gt_texts = batch["gt_text"] + sample_ids = batch["sample_id"] + + # Chat-style instruction formatting for T2S: user prompt + text + prompts = [ + f"<|start_header_id|>user<|end_header_id|>\n{random.choice(T2S_INSTRUCTION)}\n{text}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n" + for text in gt_texts + ] + batch_size = len(prompts) + + # We need a reasonable length for generated audio tokens + speech_token_length = 384 - 1 # -1 for soa token + audio_tokens = torch.ones((batch_size, speech_token_length), dtype=torch.long, device=accelerator.device) * mask_token_id + input_ids, attention_mask = uni_prompting((prompts, audio_tokens), 't2s_gen') + + if config.training.guidance_scale > 0: + uncond_input_ids, uncond_attention_mask = uni_prompting(([''] * batch_size, audio_tokens), 't2s_gen') + else: + uncond_input_ids, uncond_attention_mask = None, None + + output_ids = unwrapped_model.t2s_generate( + input_ids=input_ids, + uncond_input_ids=uncond_input_ids, + attention_mask=attention_mask, + uncond_attention_mask=uncond_attention_mask, + guidance_scale=5.0, + temperature=1.0, + timesteps=50, + noise_schedule=mask_schedule, + noise_type="mask", + seq_len=383, + uni_prompting=uni_prompting, + config=config, + ) + + # Decode and run Whisper + for i in range(batch_size): + gt = gt_texts[i].rsplit("\n", 1)[-1].strip() + + gen_speech_tokens = output_ids[i] + + # Remove padding/eos if necessary, clamp to valid range + # gen_speech_tokens = torch.clamp(gen_speech_tokens, min=0, max= 4096 - 1) + id_list = gen_speech_tokens.cpu().tolist() + + if not id_list: + logger.warning(f"Generated token list is empty for sample {sample_ids[i]}. Skipping.") + continue + + speech_unit_str = " ".join(map(str, id_list)) + speech_unit_for_decode = "".join([f"<|speech_{unit}|>" for unit in speech_unit_str.split(" ")]) + + filename = f"process_{accelerator.process_index}_{sample_ids[i]}.wav" + output_wav_path = os.path.join(output_dir_per_step, filename) + condition = 'gender-female_emotion-neutral_speed-normal_pitch-normal' + + audio_array = vq_model_audio.decode(speech_unit_for_decode, condition=condition, output_wav_file=output_wav_path) + + whisper_result = whisper_pipe(output_wav_path, generate_kwargs={"language": "english"}) + whisper_text = whisper_result.get("text", "") + + local_results.append({ + "sample_id": sample_ids[i], "gt_text": gt, "whisper_text": whisper_text, "audio_path": output_wav_path + }) + + if not local_results: + logger.warning("Skipping T2S evaluation logging because no samples were generated.") + return + + gt_list = [res["gt_text"] for res in local_results] + pred_list = [res["whisper_text"] for res in local_results] + + wer, errors, words = calculate_wer(pred_list, gt_list) + logger.info(f"T2S Final WER (via Whisper): {wer:.4f} | Word Errors: {errors} | Total Words: {words}") + + accelerator.log({ + "eval/t2s_wer": wer, + "eval/t2s_word_errors": errors, + "eval/t2s_total_words": words + }, step=global_step) + + results_table = wandb.Table(columns=["ID", "Ground Truth", "Whisper Transcription", "Generated Audio"]) + for res in local_results[:8]: + audio = wandb.Audio(res["audio_path"], caption=res["whisper_text"]) + results_table.add_data(res["sample_id"], res["gt_text"], res["whisper_text"], audio) + + accelerator.log({"eval/t2s_samples": results_table}, step=global_step) + +@torch.no_grad() +def evaluate_t2s_mmu_like(model, vq_model_audio, uni_prompting, config, accelerator, global_step): + """Text-to-speech evaluation using the MMU-style block refinement decoder.""" + + if not accelerator.is_main_process: + return + + logger.info("***** Running T2S Evaluation (MMU-style decoder) *****") + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.eval() + + try: + t2s_eval_dataset_raw = load_dataset("librispeech_asr", "clean", split="test").select(range(32)) + whisper_pipe = pipeline("automatic-speech-recognition", model="openai/whisper-large-v3", device=accelerator.device) + os.makedirs(f"{config.experiment.output_dir}/eval_audio", exist_ok=True) + except Exception as exc: + logger.error(f"Failed to load T2S dataset or Whisper model for MMU-style eval: {exc}") + return + + output_dir_per_step = os.path.join("/home/work/AIDAS", config.experiment.output_dir, "eval_audio", f"step_{global_step}_mmu") + os.makedirs(output_dir_per_step, exist_ok=True) + + t2s_eval_dataset = T2SEvalDataset(t2s_eval_dataset_raw) + t2s_dataloader = DataLoader(t2s_eval_dataset, batch_size=config.training.batch_size_t2s) + + local_results = [] + mask_token_id = unwrapped_model.config.mask_token_id + + codebook_size = config.model.omada.codebook_size + speech_vocab_size = 4096 + + for batch in tqdm(t2s_dataloader, desc="T2S MMU Eval"): + gt_texts = batch["gt_text"] + sample_ids = batch["sample_id"] + + prompts = [ + f"<|start_header_id|>user<|end_header_id|>\n{random.choice(T2S_INSTRUCTION)}\n{text}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n" + for text in gt_texts + ] + + batch_size = len(prompts) + speech_token_length = 384 - 1 + audio_tokens = torch.ones((batch_size, speech_token_length), dtype=torch.long, device=accelerator.device) * mask_token_id + input_ids, attention_mask = uni_prompting((prompts, audio_tokens), 't2s_gen') + + output_ids = unwrapped_model.t2s_generate_mmu_like( + input_ids=input_ids, + max_new_tokens=speech_token_length, + steps=384 - 1, + block_length=384 - 1, + temperature=1.0, + cfg_scale=1.5, + mask_token_id=mask_token_id, + attention_mask=attention_mask, + uni_prompting=uni_prompting, + codebook_size=codebook_size, + audio_codebook_size=speech_vocab_size, + ) + + for i in range(batch_size): + gt = gt_texts[i].rsplit("\n", 1)[-1].strip() + + gen_speech_tokens = output_ids[i] + if isinstance(gen_speech_tokens, torch.Tensor): + gen_speech_tokens = gen_speech_tokens.detach().cpu() + + token_list = gen_speech_tokens.tolist() + if not token_list: + logger.warning(f"Generated token list is empty for sample {sample_ids[i]} (MMU eval). Skipping.") + continue + + speech_unit_str = " ".join(map(str, token_list)) + speech_unit_for_decode = "".join([f"<|speech_{unit}|>" for unit in speech_unit_str.split(" ")]) + + filename = f"process_{accelerator.process_index}_{sample_ids[i]}_mmu.wav" + output_wav_path = os.path.join(output_dir_per_step, filename) + condition = 'gender-female_emotion-neutral_speed-normal_pitch-normal' + + try: + vq_model_audio.decode(speech_unit_for_decode, condition=condition, output_wav_file=output_wav_path) + except Exception as exc: + logger.error(f"Decoding failed for sample {sample_ids[i]} (MMU eval): {exc}") + continue + + whisper_result = whisper_pipe(output_wav_path, generate_kwargs={"language": "english"}) + whisper_text = whisper_result.get("text", "") + + local_results.append({ + "sample_id": sample_ids[i], + "gt_text": gt, + "whisper_text": whisper_text, + "audio_path": output_wav_path, + }) + + if not local_results: + logger.warning("Skipping T2S MMU-style evaluation because no samples were generated.") + return + + gt_list = [res["gt_text"] for res in local_results] + pred_list = [res["whisper_text"] for res in local_results] + + wer, errors, words = calculate_wer(pred_list, gt_list) + logger.info(f"T2S (MMU-style) Final WER: {wer:.4f} | Word Errors: {errors} | Total Words: {words}") + + accelerator.log({ + "eval/t2s_mmu_like_wer": wer, + "eval/t2s_mmu_like_word_errors": errors, + "eval/t2s_mmu_like_total_words": words, + }, step=global_step) + + results_table = wandb.Table(columns=["ID", "Ground Truth", "Whisper Transcription", "Generated Audio"]) + for res in local_results[:8]: + audio = wandb.Audio(res["audio_path"], caption=res["whisper_text"]) + results_table.add_data(res["sample_id"], res["gt_text"], res["whisper_text"], audio) + + accelerator.log({"eval/t2s_mmu_like_samples": results_table}, step=global_step) + +@torch.no_grad() +def evaluate_t2s_fixed(model, vq_model_audio, uni_prompting, config, accelerator, global_step): + """ + Text-to-Speech (fixed-length) evaluation: + - Input prompt contains SOA + [MASK]*L + EOA (EOA is injected, not predicted) + - The model only fills VQ codes for exactly L positions (no EOA/EOS prediction) + - Generated audio is transcribed by Whisper; we report WER + """ + if not accelerator.is_main_process: + return + logger.info("***** Running T2S (fixed-length) Evaluation *****") + unwrapped = accelerator.unwrap_model(model) + unwrapped.eval() + + # Load eval dataset and Whisper model + try: + ds_raw = load_dataset("librispeech_asr", "clean", split="test").select(range(128)) + whisper_pipe = pipeline( + "automatic-speech-recognition", + model="openai/whisper-large-v3", + device=accelerator.device + ) + os.makedirs(f"{config.experiment.output_dir}/eval_audio", exist_ok=True) + except Exception as e: + logger.error(f"Failed to load dataset or Whisper model: {e}") + return + + # Directory for saving generated audio files of this evaluation step + out_dir = os.path.join( + "/home/work/AIDAS", config.experiment.output_dir, "eval_audio", f"step_{global_step}_fixed" + ) + os.makedirs(out_dir, exist_ok=True) + + eval_ds = T2SEvalDataset(ds_raw) + loader = DataLoader(eval_ds, batch_size=config.training.batch_size_t2s) + + local_results = [] + mask_token_id = unwrapped.config.mask_token_id + mask_schedule = get_mask_schedule(config.training.get("mask_schedule", "cosine")) + + for batch in tqdm(loader, desc="T2S Fixed Evaluation"): + gt_texts = batch["gt_text"] + sample_ids = batch["sample_id"] + + # Chat-style instruction formatting for fixed-length T2S + prompts = [ + f"<|start_header_id|>user<|end_header_id|>\n{random.choice(T2S_INSTRUCTION)}\n{text}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n" + for text in gt_texts + ] + batch_size = len(prompts) + + # We need a reasonable length for generated audio tokens + speech_token_length = 256 - 2 # exclude and + audio_tokens = torch.ones((batch_size, speech_token_length), dtype=torch.long, device=accelerator.device) * mask_token_id + input_ids, attention_mask = uni_prompting((prompts, audio_tokens), 't2s_fixed_gen') + + if config.training.guidance_scale > 0: + uncond_input_ids, uncond_attention_mask = uni_prompting(([''] * batch_size, audio_tokens), 't2s_fixed_gen') + else: + uncond_input_ids, uncond_attention_mask = None, None + + # Core generation call: + # - predict_eoa=False prevents EOA/EOS prediction; only VQ codes are generated + outputs = unwrapped.t2s_fixed_generate( + input_ids=input_ids, + uncond_input_ids=uncond_input_ids, + attention_mask=attention_mask, + uncond_attention_mask=uncond_attention_mask, + guidance_scale=1.5, + temperature=1.0, + timesteps=24, + noise_schedule=mask_schedule, + noise_type="mask", + seq_len=speech_token_length, + uni_prompting=uni_prompting, + config=config, + ) + + # Decode generated VQ codes → waveform via the speech tokenizer, then ASR with Whisper + for i in range(batch_size): + gt = gt_texts[i].rsplit("\n", 1)[-1].strip() + gen_rel = outputs[i] # relative VQ ids in [0..4095] + id_list = gen_rel.tolist() + + if not id_list: + logger.warning(f"[fixed] Empty tokens for {sample_ids[i]}; skipping.") + continue + + # Convert to the speech-unit string format expected by the decoder + unit_str = " ".join(map(str, id_list)) + speech_unit_for_decode = "".join([f"<|speech_{u}|>" for u in unit_str.split(" ")]) + + # Synthesize audio and run Whisper + fname = f"process_{accelerator.process_index}_{sample_ids[i]}_fixed.wav" + wav_path = os.path.join(out_dir, fname) + condition = 'gender-female_emotion-neutral_speed-normal_pitch-normal' + + _ = vq_model_audio.decode( + speech_unit_for_decode, + condition=condition, + output_wav_file=wav_path + ) + asr = whisper_pipe(wav_path, generate_kwargs={"language": "english"}) + whisper_text = asr.get("text", "") + + local_results.append({ + "sample_id": sample_ids[i], + "gt_text": gt, + "whisper_text": whisper_text, + "audio_path": wav_path + }) + + if not local_results: + logger.warning("Skipping T2S fixed evaluation logging because no samples were generated.") + return + + gt_list = [r["gt_text"] for r in local_results] + pred_list = [r["whisper_text"] for r in local_results] + wer, errors, words = calculate_wer(pred_list, gt_list) + logger.info(f"T2S Fixed WER: {wer:.4f} | Errors: {errors} | Words: {words}") + + accelerator.log({ + "eval/t2s_fixed_wer": wer, + "eval/t2s_fixed_errors": errors, + "eval/t2s_fixed_words": words + }, step=global_step) + + table = wandb.Table(columns=["ID", "GT", "ASR", "Audio"]) + for r in local_results[:8]: + table.add_data( + r["sample_id"], + r["gt_text"], + r["whisper_text"], + wandb.Audio(r["audio_path"], caption=r["whisper_text"]) + ) + accelerator.log({"eval/t2s_fixed_samples": table}, step=global_step) + +################################################################################################ +# +++++++++++++++++++++++++++++++++++++ V2T EVALUATION LOGIC +++++++++++++++++++++++++++++++++++++ +################################################################################################ +@torch.no_grad() +def evaluate_v2t(model, vq_model_image, uni_prompting, config, accelerator, global_step): + # This is a qualitative evaluation, so it only runs on the main process. + if not accelerator.is_main_process: + return + + logger.info("***** Running V2T Qualitative Evaluation *****") + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.eval() + + video_root = "/home/work/AIDAS/video/demo" + if not video_root or not os.path.exists(video_root): + logger.warning(f"V2T eval root '{video_root}' not found. Skipping V2T evaluation.") + return + + file_list = [f for f in os.listdir(video_root) if f.lower().endswith('.mp4')] + if not file_list: + logger.warning(f"No .mp4 files found in '{video_root}'. Skipping V2T evaluation.") + return + + question = "Please provide a detailed description of the video." + results_table = wandb.Table(columns=["Video ID", "Question", "Generated Caption"]) + + for file_name in tqdm(file_list[:], desc="V2T Evaluation", disable=not accelerator.is_main_process): + video_path = os.path.join(video_root, file_name) + + # 1. Load and process video + cap = cv2.VideoCapture(video_path) + total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + indices = np.linspace(0, total_frames - 1, 8, dtype=int) + frames = [] + for i in range(total_frames): + ret, frame = cap.read() + if i in indices: + if not ret: continue + frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + pil_img = Image.fromarray(frame) + frames.append(image_transform(pil_img, resolution=config.dataset.preprocessing.resolution)) + cap.release() + + if len(frames) < 8: continue + + video_tensor = torch.stack(frames).to(accelerator.device) + video_tokens = vq_model_image.get_code(video_tensor) + len(uni_prompting.text_tokenizer) + video_tokens = video_tokens.view(1, -1) # Flatten tokens + + sptids = uni_prompting.sptids_dict + device = unwrapped_model.device + + prompt_text = f'<|start_header_id|>user<|end_header_id|>\n{question}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n' + prompt_tensor = uni_prompting.text_tokenizer(prompt_text, return_tensors="pt").input_ids.to(device) + + input_ids = torch.cat([ + sptids['<|v2t|>'].to(device).unsqueeze(0), + sptids['<|soi|>'].to(device).unsqueeze(0), + video_tokens, + sptids['<|eoi|>'].to(device).unsqueeze(0), + sptids['<|sot|>'].to(device).unsqueeze(0), + prompt_tensor + ], dim=1).long() + + output_ids = unwrapped_model.mmu_generate(input_ids, max_new_tokens=256, steps=256, block_length=128) + text = uni_prompting.text_tokenizer.batch_decode(output_ids[:, input_ids.shape[1]:], skip_special_tokens=True)[0] + print(text) + # 3. Log result + results_table.add_data(file_name, question, text) + + # except Exception as e: + # logger.error(f"Error processing video {file_name}: {e}") + + accelerator.log({"eval/v2t_qualitative_samples": results_table}, step=global_step) + + +################################################################################################ +# +++++++++++++++++++++++++++++++++++++ MAIN EVALUATION ORCHESTRATOR +++++++++++++++++++++++++++++ +################################################################################################ + +def run_evaluation(model, vq_model_image, vq_model_audio, uni_prompting, config, accelerator, global_step): + """ + Orchestrates the S2T, T2S, and V2T evaluations. + """ + if accelerator.is_main_process: + logger.info(f"--- Starting evaluation at step {global_step} ---") + model.eval() + + if accelerator.is_main_process: + evaluate_s2t(model, vq_model_audio, uni_prompting, config, accelerator, global_step) + # evaluate_t2s(model, vq_model_audio, uni_prompting, config, accelerator, global_step) + evaluate_t2s_mmu_like(model, vq_model_audio, uni_prompting, config, accelerator, global_step) + # evaluate_t2s_fixed(model, vq_model_audio, uni_prompting, config, accelerator, global_step) + evaluate_v2t(model, vq_model_image, uni_prompting, config, accelerator, global_step) + + accelerator.wait_for_everyone() + if accelerator.is_main_process: + logger.info(f"--- Finished evaluation at step {global_step}. Returning to training. ---") + model.train() + + +def main(): + ######################### + # SETUP Accelerator # + ######################### + config = get_config() + + # Enable TF32 on Ampere GPUs + if config.training.enable_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.benchmark = True + torch.backends.cudnn.deterministic = False + + config.experiment.logging_dir = str(Path(config.experiment.output_dir) / "logs") + accelerator = Accelerator( + gradient_accumulation_steps=config.training.gradient_accumulation_steps, + mixed_precision=config.training.mixed_precision, + log_with="wandb", + project_dir=config.experiment.logging_dir, + split_batches=True, + ) + + total_batch_size_per_gpu = (config.training.batch_size_t2i + + config.training.batch_size_lm + + config.training.batch_size_mmu + + config.training.batch_size_v2t + + config.training.batch_size_s2t + + config.training.batch_size_t2s) + total_batch_size = ( + (config.training.batch_size_t2i + + config.training.batch_size_lm + + config.training.batch_size_mmu + + config.training.batch_size_v2t + + config.training.batch_size_s2t + + config.training.batch_size_t2s) * accelerator.num_processes * config.training.gradient_accumulation_steps + ) + + if accelerator.distributed_type == DistributedType.DEEPSPEED: + accelerator.state.deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = ( + total_batch_size_per_gpu + ) + + ##################################### + # SETUP LOGGING, SEED and CONFIG # + ##################################### + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + set_verbosity_info() + else: + set_verbosity_error() + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + resume_wandb_run = config.wandb.resume + run_id = config.wandb.get("run_id", None) + if run_id is None: + resume_wandb_run = False + run_id = wandb.util.generate_id() + config.wandb.run_id = run_id + + wandb_init_kwargs = dict( + name=config.experiment.name, + id=run_id, + resume=resume_wandb_run, + entity=config.wandb.get("entity", None), + config_exclude_keys=[], + dir = config.experiment.logging_dir, + ) + wandb_config = {k: v for k, v in flatten_omega_conf(config, resolve=True)} + wandb_config.pop("experiment.resume_from_checkpoint") + + accelerator.init_trackers( + config.experiment.project, + config=wandb_config, + init_kwargs={"wandb": wandb_init_kwargs}, + ) + + if accelerator.is_main_process: + os.makedirs(config.experiment.output_dir, exist_ok=True) + config_path = Path(config.experiment.output_dir) / "config.yaml" + logging.info(f"Saving config to {config_path}") + OmegaConf.save(config, config_path) + + # If passed along, set the training seed now. + if config.training.seed is not None: + set_seed(config.training.seed) + + ######################### + # MODELS and OPTIMIZER # + ######################### + logger.info("Loading models and optimizer") + + tokenizer = AutoTokenizer.from_pretrained(config.model.omada.tokenizer_path, padding_side="left") + + uni_prompting = UniversalPrompting(tokenizer, max_text_len=config.dataset.preprocessing.max_seq_length, max_audio_len=config.dataset.preprocessing.max_aud_length, + special_tokens=( + "<|soi|>", "<|eoi|>", "<|sov|>", "<|eov|>", "<|t2i|>", + "<|mmu|>", "<|t2v|>", "<|v2v|>", "<|lvg|>", + # Omada Special Tokens + "<|v2t|>", "<|s2t|>", "<|t2s|>", "<|soa|>", "<|eoa|>", + ), + ignore_id=-100, cond_dropout_prob=config.training.cond_dropout_prob, use_reserved_token=True) + + print('special tokens : \n', uni_prompting.sptids_dict) + + speech_vocab_start = len(uni_prompting.text_tokenizer) + int(config.model.omada.codebook_size) + audio_codebook_size = max(int(config.model.omada.new_vocab_size) - speech_vocab_start, 0) + t2s_special_token_ids = { + "eoa": int(uni_prompting.sptids_dict['<|eoa|>'][0].item()), + "eos": int(uni_prompting.text_tokenizer.eos_token_id), + } + + # VQ model for processing image into discrete tokens + vq_model_image = get_vq_model_class(config.model.vq_model_image.type) + if config.model.vq_model_image.get("pretrained_model_path", None): + vq_model_image = vq_model_image().to(accelerator.device) + state_dict = torch.load(config.model.vq_model_image.pretrained_model_path)['model'] + vq_model_image.load_state_dict(state_dict) + else: + vq_model_image = vq_model_image.from_pretrained(config.model.vq_model_image.vq_model_name).to(accelerator.device) + + vq_model_audio = get_vq_model_class(config.model.vq_model_audio.type) + vq_model_audio = vq_model_audio.from_pretrained(config.model.vq_model_audio.vq_model_name).to(accelerator.device) + + vq_model_image.eval() + vq_model_image.requires_grad_(False) + + vq_model_audio.eval() + vq_model_audio.requires_grad_(False) + + model = OMadaModelLM.from_pretrained(config.model.omada.pretrained_model_path, torch_dtype=torch.bfloat16).to(accelerator.device) + + # Resize Vocab size for Audio Modality + unwrapped_model = accelerator.unwrap_model(model) + original_vocab_size = unwrapped_model.get_input_embeddings().weight.shape[0] + logger.info("="*50) + logger.info(f"Calling resize_vocab...") + logger.info(f"Vocab size BEFORE resizing: {original_vocab_size}") + + resize_vocab(unwrapped_model, config) + + resized_vocab_size = unwrapped_model.get_input_embeddings().weight.shape[0] + logger.info(f"Vocab size AFTER resizing: {resized_vocab_size}") + logger.info(f"Config 'new_vocab_size': {config.model.omada.new_vocab_size}") + + if resized_vocab_size == config.model.omada.new_vocab_size: + logger.info("āœ… Vocab resize successful!") + else: + logger.info("āŒ Vocab resize FAILED or did not match config!") + logger.info("="*50) + mask_id = model.config.mask_token_id + + ################################## + # Optimizer and LR scheduler # + ################################# + optimizer_config = config.optimizer.params + + # no decay on bias and layernorm and embedding + no_decay = ["bias", "layer_norm.weight", "mlm_ln.weight", "embeddings.weight"] + optimizer_grouped_parameters = [ + { + "params": [p for n, p in model.named_parameters() if + p.requires_grad and not any(nd in n for nd in no_decay)], + "weight_decay": optimizer_config.weight_decay, + }, + { + "params": [p for n, p in model.named_parameters() if + p.requires_grad and any(nd in n for nd in no_decay)], + "weight_decay": 0.0, + }, + ] + + optimizer_type = config.optimizer.name + if optimizer_type == "adamw": + optimizer = AdamW( + optimizer_grouped_parameters, + lr=optimizer_config.learning_rate, + betas=(optimizer_config.beta1, optimizer_config.beta2), + weight_decay=optimizer_config.weight_decay, + eps=optimizer_config.epsilon, + ) + else: + raise ValueError(f"Optimizer {optimizer_type} not supported") + + # Create mask scheduler + if config.get("mask_schedule", None) is not None: + schedule = config.mask_schedule.schedule + args = config.mask_schedule.get("params", {}) + mask_schedule = get_mask_schedule(schedule, **args) + else: + mask_schedule = get_mask_schedule(config.training.get("mask_schedule", "cosine")) + + ################################## + # DATALOADER # + ################################# + logger.info("Creating dataloaders and lr_scheduler") + + total_batch_size = ( + (config.training.batch_size_t2s + config.training.batch_size_s2t +config.training.batch_size_v2t) * accelerator.num_processes * config.training.gradient_accumulation_steps + ) + preproc_config = config.dataset.preprocessing + dataset_config = config.dataset.params + + # Video Dataset + video_captioning_dataset = VideoCaptionDataset( + transform=image_transform, + tokenizer=uni_prompting.text_tokenizer, + max_seq_length=preproc_config.max_seq_length, + resolution=preproc_config.resolution, + sample_method="uniform", + num_frames=8, + ) + + sampler_v2t = DistributedSampler( + video_captioning_dataset, + shuffle=True, # Should be true for training + drop_last=True + ) + + train_dataloader_v2t = DataLoader( + video_captioning_dataset, + batch_size=config.training.batch_size_v2t, + num_workers=dataset_config.num_workers, + collate_fn=collate_fn_video_caption, + sampler = sampler_v2t, + drop_last=True, + ) + + # ##### + # 추가 + # ##### + ## multi image mmu dataset (mantis-instruct) + dataset_mmu = TextImageInterleavedDataset() + + # Speech Dataset + dataset_sm = MixedSpeechTextDataset(config.dataset.params.audio_data) + + logger.info(f"Dataset Prepared.") + + # Use distinct DistributedSamplers for each speech dataloader to avoid iterator interference + if accelerator.num_processes > 1: + sampler_s2t = DistributedSampler( + dataset_sm, + num_replicas=accelerator.num_processes, + rank=accelerator.process_index, + shuffle=True, + drop_last=True, + ) + sampler_t2s = DistributedSampler( + dataset_sm, + num_replicas=accelerator.num_processes, + rank=accelerator.process_index, + shuffle=True, + drop_last=True, + ) + # ##### + # 추가 + # ##### + sampler_mmu = DistributedSampler( + dataset_mmu, + num_replicas=accelerator.num_processes, + rank=accelerator.process_index, + shuffle=True, + drop_last=True, + ) + else: + sampler_s2t = None + sampler_t2s = None + sampler_mmu = None + + train_dataloader_s2t = DataLoader( + dataset_sm, + batch_size=config.training.batch_size_s2t, + shuffle=False, + sampler=sampler_s2t, + collate_fn=collate_fn_audio, + num_workers=config.dataset.params.num_workers, + drop_last=True, + ) + train_dataloader_t2s = DataLoader( + dataset_sm, + batch_size=config.training.batch_size_t2s, + shuffle=False, + sampler=sampler_t2s, + collate_fn=collate_fn_audio, + num_workers=config.dataset.params.num_workers, + drop_last=True, + ) + # ##### + # 추가 + # ##### + train_dataloader_mmu = DataLoader( + dataset_mmu, + batch_size=config.training.batch_size_mmu, + shuffle=False, + sampler=sampler_mmu, + collate_fn=collate_fn_mmu_mult, + num_workers=config.dataset.params.num_workers, + drop_last=True, + ) + + + + # Combine these dataloaders into a single iterable model + iterables = { + "v2t_flow": train_dataloader_v2t, + "t2s_flow": train_dataloader_t2s, + "s2t_flow": train_dataloader_s2t, + "mmu_flow": train_dataloader_mmu, + } + + combined_dataloader = CombinedLoader(iterables, mode=config.dataset.combined_loader_mode) + + # s2t + total_batch_size_s2t = config.training.batch_size_s2t * accelerator.num_processes * config.training.gradient_accumulation_steps + num_update_steps_per_epoch_s2t = math.ceil(len(dataset_sm) / total_batch_size_s2t) + + # t2s + total_batch_size_t2s = config.training.batch_size_t2s * accelerator.num_processes * config.training.gradient_accumulation_steps + num_update_steps_per_epoch_t2s = math.ceil(len(dataset_sm) / total_batch_size_t2s) + + # v2t + total_batch_size_v2t = (config.training.batch_size_v2t * accelerator.num_processes * config.training.gradient_accumulation_steps) + num_update_steps_per_epoch_v2t = math.ceil(len(video_captioning_dataset) / total_batch_size_v2t) + + # ##### + # 추가 + # ##### + # mmu + total_batch_size_mmu = (config.training.batch_size_mmu * accelerator.num_processes * config.training.gradient_accumulation_steps) + num_update_steps_per_epoch_mmu = math.ceil(len(dataset_mmu) / total_batch_size_mmu) + + + # Calculate num_train_epochs + num_update_steps_per_epoch = max(num_update_steps_per_epoch_s2t, num_update_steps_per_epoch_t2s, num_update_steps_per_epoch_v2t, num_update_steps_per_epoch_mmu) + num_train_epochs = math.ceil(config.training.max_train_steps / num_update_steps_per_epoch) if num_update_steps_per_epoch > 0 else 1 + + logger.info(f"len of speech: {len(dataset_sm)}") + logger.info(f"len of video: {len(video_captioning_dataset)}") + logger.info(f"len of image: {len(dataset_mmu)}") + logger.info(f"Train stpes: {config.training.max_train_steps}") + logger.info(f"Num train epochs: {num_train_epochs}") + + ################################## + # MODEL RESUME # + ################################# + global_step = 0 + first_epoch = 0 + start_step = 0 + + if config.experiment.resume_from_checkpoint: + dirs = os.listdir(config.experiment.output_dir) + logger.info(f"dirs: {dirs}") + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + logger.info(f"path: {path}") + if path is not None: + path = os.path.join(config.experiment.output_dir, path) + logger.info(f"Resuming from checkpoint: {path}") + global_step = start_step = int(os.path.basename(path).split("-")[1]) + first_epoch = global_step // num_update_steps_per_epoch + if os.path.exists(f'{path}/unwrapped_model/pytorch_model.bin'): + state_dict = torch.load(f'{path}/unwrapped_model/pytorch_model.bin', map_location="cpu") + model.load_state_dict(state_dict, strict=True) + del state_dict + elif os.path.exists(f'{path}/unwrapped_model/pytorch_model.bin.index.json'): + from safetensors.torch import load_file + from transformers.modeling_utils import load_sharded_checkpoint + load_sharded_checkpoint(model, f'{path}/unwrapped_model/') + # if safetensors sharded checkpoint exists + elif os.path.exists(f'{path}/unwrapped_model/model.safetensors.index.json'): + from transformers.modeling_utils import load_sharded_checkpoint + load_sharded_checkpoint( + model, + f'{path}/unwrapped_model/', + ) + else: + raise FileNotFoundError(f"Checkpoint {path}/unwrapped_model/pytorch_model.bin or safetensors not found") + else: + logger.info("Not resuming from checkpoint") + + ################################## + # Prepare accelerator # + ################################# + logger.info("Preparing model, optimizer and dataloaders") + + lr_scheduler = get_scheduler( + config.lr_scheduler.scheduler, + optimizer=optimizer, + num_training_steps=config.training.max_train_steps, + num_warmup_steps=config.lr_scheduler.params.warmup_steps, + min_lr_scale=config.lr_scheduler.params.min_lr_scale + ) + + # model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler) + model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler) + + lr_scheduler = get_scheduler( + config.lr_scheduler.scheduler, + optimizer=optimizer, + num_training_steps=config.training.max_train_steps, + num_warmup_steps=config.lr_scheduler.params.warmup_steps, + min_lr_scale=config.lr_scheduler.params.min_lr_scale + ) + + vq_model_image.to(device=accelerator.device) + vq_model_audio.to(device=accelerator.device) + + mask_dtype = model.get_input_embeddings().weight.dtype + + def _log_and_flag_failure(message: str, exc: Exception = None): + """Log preprocessing failures on both logger and accelerator console.""" + if exc is not None: + logger.exception(message) + else: + logger.error(message) + accelerator.print(message) + + def safe_audio_encode(audio_path: str, flow_name: str): + try: + tokens = vq_model_audio.encode(audio_path) + return tokens, None + except Exception as exc: + msg = ( + f"[Rank {accelerator.process_index}] {flow_name} audio encode failed " + f"for '{audio_path}': {exc}" + ) + _log_and_flag_failure(msg, exc) + return None, msg + + def safe_video_get_code(video_tensor_sample: torch.Tensor, sample_index: int): + try: + video_token = vq_model_image.get_code(video_tensor_sample) + return video_token, None + except Exception as exc: + msg = ( + f"[Rank {accelerator.process_index}] v2t video encode failed " + f"for sample index {sample_index}: {exc}" + ) + _log_and_flag_failure(msg, exc) + return None, msg + + # ##### + # 추가 + # ##### + def safe_image_get_code(image_tensor_sample: torch.Tensor, sample_index: int): + try: + image_token = vq_model_image.get_code(image_tensor_sample) + return image_token, None + except Exception as exc: + msg = ( + f"[Rank {accelerator.process_index}] image encode failed " + f"for sample index {sample_index}: {exc}" + ) + _log_and_flag_failure(msg, exc) + return None, msg + + + ################################## + # Training # + ################################# + logger.info("***** Running training *****") + logger.info(f" Num training steps = {config.training.max_train_steps}") + logger.info(f" Instantaneous batch size per device = {total_batch_size_per_gpu}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {config.training.gradient_accumulation_steps}") + + + @torch.no_grad() + def prepare_inputs_and_labels( + pixel_values_or_image_ids: Union[torch.FloatTensor, torch.LongTensor], + texts: Union[str, str], + min_masking_rate: float = 0.0, + is_train: bool = True, + seed: int = None + ): + + image_tokens = vq_model_image.get_code(pixel_values_or_image_ids) + image_tokens = image_tokens + len(uni_prompting.text_tokenizer) + # create MLM mask and labels + input_ids, labels, loss_weight, mask_prob = mask_or_random_replace_tokens( + image_tokens, + mask_id, + config, + mask_schedule=mask_schedule, + is_train=is_train, + ) + input_ids, masks, labels = uni_prompting((texts, input_ids, labels), 't2i') + return input_ids, labels, mask_prob, image_tokens, masks + + @torch.no_grad() + def prepare_inputs_and_labels_for_text( + texts: Union[str, str], max_seq_len, eps=1e-3 + ): + # create MLM mask and labels + + input_ids_lm, prompt_mask, labels_lm = uni_prompting((texts_lm, max_seq_len), 'lm') + b, l = input_ids_lm.shape + t = torch.rand(b, device=input_ids_lm.device) + p_mask = (1 - eps) * t + eps + p_mask = p_mask[:, None].repeat(1, l) + + masked_indices = torch.rand((b, l), device=input_ids_lm.device) < p_mask + # 126336 is used for [MASK] token + noisy_batch = torch.where(masked_indices, mask_id, input_ids_lm) + masked_indices = noisy_batch == mask_id + + return noisy_batch, labels_lm, p_mask + + # Video also uses this. + @torch.no_grad() + def prepare_inputs_and_labels_for_mmu( + input_ids_mmu, prompt_masks, labels_mmu, eps=1e-3 + ): + b, l = input_ids_mmu.shape + t = torch.rand(b, device=input_ids_mmu.device) + p_mask = (1 - eps) * t + eps + p_mask = p_mask[:, None].repeat(1, l) + + masked_indices = torch.rand((b, l), device=input_ids_mmu.device) < p_mask + # 126336 is used for [MASK] token + noisy_batch = torch.where(masked_indices, mask_id, input_ids_mmu) + masked_indices = noisy_batch == mask_id + noisy_batch[prompt_masks.bool()] = input_ids_mmu[prompt_masks.bool()] + masked_indices = noisy_batch == mask_id + + prompt_masks = prompt_masks.to(torch.int64) + answer_lengths = torch.sum((1 - prompt_masks), dim=-1, keepdim=True) + answer_lengths = answer_lengths.repeat(1, noisy_batch.shape[1]) + + return noisy_batch, labels_mmu, p_mask, answer_lengths + + @torch.no_grad() + def prepare_inputs_and_labels_for_t2s( + input_ids_t2s, prompt_masks, labels_t2s, eps=1e-3 + ): + b, l = input_ids_t2s.shape + t = torch.rand(b, device=input_ids_t2s.device) + p_mask = (1 - eps) * t + eps + p_mask = p_mask[:, None].repeat(1, l) + + masked_indices = torch.rand((b, l), device=input_ids_t2s.device) < p_mask + noisy_batch = torch.where(masked_indices, mask_id, input_ids_t2s) + masked_indices = noisy_batch == mask_id + + noisy_batch[prompt_masks.bool()] = input_ids_t2s[prompt_masks.bool()] + masked_indices = noisy_batch == mask_id + + prompt_masks = prompt_masks.to(torch.int64) + answer_lengths = torch.sum((1 - prompt_masks), dim=-1, keepdim=True) + answer_lengths = answer_lengths.repeat(1, noisy_batch.shape[1]) + + return noisy_batch, labels_t2s, p_mask, answer_lengths + + + @torch.no_grad() + def prepare_inputs_and_labels_for_s2t( + input_ids_mmu, prompt_masks, labels_mmu, eps=1e-3 + ): + b, l = input_ids_mmu.shape + t = torch.rand(b, device=input_ids_mmu.device) + p_mask = (1 - eps) * t + eps + p_mask = p_mask[:, None].repeat(1, l) + + masked_indices = torch.rand((b, l), device=input_ids_mmu.device) < p_mask + # 126336 is used for [MASK] token + noisy_batch = torch.where(masked_indices, mask_id, input_ids_mmu) + masked_indices = noisy_batch == mask_id + noisy_batch[prompt_masks.bool()] = input_ids_mmu[prompt_masks.bool()] + masked_indices = noisy_batch == mask_id + + prompt_masks = prompt_masks.to(torch.int64) + answer_lengths = torch.sum((1 - prompt_masks), dim=-1, keepdim=True) + answer_lengths = answer_lengths.repeat(1, noisy_batch.shape[1]) + + return noisy_batch, labels_mmu, p_mask, answer_lengths + + batch_time_m = AverageMeter() + data_time_m = AverageMeter() + end = time.time() + + for epoch in tqdm(range(first_epoch, num_train_epochs), desc="Epochs", disable=not accelerator.is_main_process, position=0): + # Ensure all samplers reshuffle in a rank-consistent way each epoch + try: + if isinstance(sampler_v2t, DistributedSampler): + sampler_v2t.set_epoch(epoch) + if accelerator.num_processes > 1: + if sampler_s2t is not None: + sampler_s2t.set_epoch(epoch) + if sampler_t2s is not None: + sampler_t2s.set_epoch(epoch) + except Exception: + pass + model.train() + for batch, batch_idx, dataloader_idx in combined_dataloader: + batch_size_t2i = 0 + batch_size_lm = 0 + + + # Synchronize skip decision across all ranks to avoid collective mismatches + local_skip = 1 if (batch is None or batch.get("v2t_flow") is None) else 0 + try: + skip_tensor = torch.tensor(local_skip, device=accelerator.device, dtype=torch.int32) + skip_sum = accelerator.reduce(skip_tensor, reduction='sum') + should_skip = skip_sum.item() > 0 + except Exception: + # Fallback if reduce isn't available for any reason + should_skip = local_skip == 1 + + if should_skip: + if accelerator.is_main_process and local_skip: + logger.warning(f"Skipping step {global_step} (batch is None or v2t_flow missing) [synced]") + continue + + batch_size_v2t = len(batch["v2t_flow"]["video"]) + batch_size_t2s = len(batch["t2s_flow"]["audio_path"]) + batch_size_s2t = len(batch["s2t_flow"]["audio_path"]) + batch_size_mmu = len(batch["mmu_flow"]["images"]) + + logger.info(f"batch_size_v2t: {batch_size_v2t}, batch_size_t2s: {batch_size_t2s}, batch_size_s2t: {batch_size_s2t}, batch_size_mmu: {batch_size_mmu}" ) + + # print(f"Rank {accelerator.process_index} loading data...") + # print(batch["s2t_flow"]["audio_path"]) + # print(batch["v2t_flow"]['captions']) + + audio_paths_s2t, texts_s2t = batch["s2t_flow"]["audio_path"], batch["s2t_flow"]["text"] + audio_paths_t2s, texts_t2s = batch["t2s_flow"]["audio_path"], batch["t2s_flow"]["text"] + offset = speech_vocab_start + video_tensor, texts_vid = batch["v2t_flow"]["video"], batch["v2t_flow"]["captions"] + image_tensor_list, texts_image = batch["mmu_flow"]["images"], batch["mmu_flow"]["text"] + + data_time_m.update(time.time() - end) + + failure_messages = [] + step_failed = False + + # *-------*-------*-------*-------*-------*-------*-------*-------*-------*-------*-------* + # Build formatted sequences for video understanding + # *-------*-------*-------*-------*-------*-------*-------*-------*-------*-------*-------* + video_tensor = video_tensor.to(accelerator.device, non_blocking=True) + video_token_list = [] + for vid_idx, video in enumerate(video_tensor): # each video is (T, C, H, W) + tokens, err = safe_video_get_code(video, vid_idx) + if err is not None: + failure_messages.append(err) + step_failed = True + break + video_token = tokens + len(uni_prompting.text_tokenizer) # add offset for video tokens + video_token = video_token.view(-1) # flatten to (T*D) + video_token_list.append(video_token) + + if not step_failed: + video_tokens = torch.stack(video_token_list, dim=0) # (B, T*D) + input_ids_vid, prompt_masks_vid, labels_vid = uni_prompting((video_tokens, texts_vid), 'v2t') + # Keep trailing EOS tokens so v2t learns to emit explicit padding. + + ( + input_ids_vid, + labels_vid, + p_mask_vid, + answer_lengths_vid + ) = prepare_inputs_and_labels_for_mmu(input_ids_vid, prompt_masks_vid, labels_vid) + + input_ids_vid = input_ids_vid.to(accelerator.device, non_blocking=True) + + # *-------*-------*-------*-------*-------*-------*-------*-------*-------*-------*-------* + # Build formatted sequences for speech understanding + # *-------*-------*-------*-------*-------*-------*-------*-------*-------*-------*-------* + if not step_failed: + prompt_s2t = ['<|start_header_id|>user<|end_header_id|>\n' + prompt + '<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n' for prompt in S2T_INSTRUCTION] + + all_audio_tokens = [] + for path in audio_paths_s2t: + tokens, err = safe_audio_encode(path, "s2t") + if err is not None: + failure_messages.append(err) + step_failed = True + break + tokens = tokens.to(accelerator.device, non_blocking=True) + tokens_with_offset = tokens + offset + all_audio_tokens.append(tokens_with_offset) + + if not step_failed: + prompt = random.choice(prompt_s2t) + texts_with_prompt = [f"{prompt}{text}" for text in texts_s2t] + + input_ids_s2t, prompt_masks_s2t, labels_s2t = uni_prompting((all_audio_tokens, texts_with_prompt), 's2t') + # Preserve trailing EOS tokens in s2t targets for explicit prediction. + input_ids_s2t, labels_s2t, p_mask_s2t, answer_lengths_s2t = prepare_inputs_and_labels_for_s2t(input_ids_s2t, prompt_masks_s2t, labels_s2t) + + # *-------*-------*-------*-------*-------*-------*-------*-------*-------*-------*-------* + # Build formatted sequences for speech generation + # *-------*-------*-------*-------*-------*-------*-------*-------*-------*-------*-------* + if not step_failed: + prompt_t2s = [prompt for prompt in T2S_INSTRUCTION] + + all_audio_tokens = [] + for path in audio_paths_t2s: + tokens, err = safe_audio_encode(path, "t2s") + if err is not None: + failure_messages.append(err) + step_failed = True + break + tokens = tokens.to(accelerator.device, non_blocking=True) + tokens_with_offset = tokens + offset + all_audio_tokens.append(tokens_with_offset) + + if not step_failed: + # Chat-style instruction formatting for T2S training + prompt = random.choice(prompt_t2s) + texts_with_prompt = [ + f"<|start_header_id|>user<|end_header_id|>\n{prompt}\n{text}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n" + for text in texts_t2s + ] + + input_ids_t2s, prompt_masks_t2s, labels_t2s = uni_prompting((texts_with_prompt, all_audio_tokens), 't2s_ip') + input_ids_t2s, labels_t2s, p_mask_t2s, answer_lengths_t2s = prepare_inputs_and_labels_for_t2s(input_ids_t2s, prompt_masks_t2s, labels_t2s) + + # *-------*-------*-------*-------*-------*-------*-------*-------*-------*-------*-------* + # Build formatted sequences for multi image mmu + # ##### + # 추가 + # ##### + # *-------*-------*-------*-------*-------*-------*-------*-------*-------*-------*-------* + if not step_failed: + all_images = [] + + batch_image_ids_list = [] # List[List[LongTensor]] per-sample list of image token vectors + texts_mmu = [] # List[str] + + for b_idx, image_list in enumerate(image_tensor_list): # per sample + per_img_ids = [] + for j, img in enumerate(image_list): # per image in sample + # encode each image -> token ids (no truncation here) + tok, err = safe_image_get_code( + img.to(accelerator.device, non_blocking=True), + sample_index=j + ) + if err is not None: + failure_messages.append(err) + step_failed = True + break + + #flatten to 1D ids + tok = tok.to(accelerator.device, non_blocking=True) + tok = tok.view(-1).long() + per_img_ids.append(tok) + + if step_failed: + break + + batch_image_ids_list.append(per_img_ids) + texts_mmu.append(texts_image[b_idx]) + + if not step_failed: + # tokenize text (no special tokens; mmu_mult_prompt handles BOS/EOS + padding) + batch_text_ids = [ + uni_prompting.text_tokenizer.encode(t, add_special_tokens=False) + for t in texts_mmu + ] + + # build (img)* + text with fixed total len + input_ids_mmu, prompt_masks_mmu, labels_mmu = uni_prompting.mmu_mult_prompt( + batch_image_ids_list=batch_image_ids_list, + batch_text_ids=batch_text_ids, + ) + + ( + input_ids_mmu, + labels_mmu, + p_mask_mmu, + answer_lengths_mmu + ) = prepare_inputs_and_labels_for_mmu(input_ids_mmu, prompt_masks_mmu, labels_mmu) + + input_ids_mmu = input_ids_mmu.to(accelerator.device, non_blocking=True) + prompt_masks_mmu = prompt_masks_mmu.to(accelerator.device, non_blocking=True) + labels_mmu = labels_mmu.to(accelerator.device, non_blocking=True) + + failure_tensor = torch.tensor(1 if step_failed else 0, device=accelerator.device, dtype=torch.int32) + failure_sum = accelerator.reduce(failure_tensor, reduction='sum') + if failure_sum.item() > 0: + if accelerator.is_main_process and failure_messages: + for msg in failure_messages: + logger.warning(f"Skipping global step {global_step} due to preprocessing failure: {msg}") + batch_time_m.reset() + data_time_m.reset() + end = time.time() + continue + + # --------------------------------------------------------------------------------- + # 1. Define padding values + pad_token_id = uni_prompting.text_tokenizer.eos_token_id + + # 2. Find the maximum sequence length in the current batch + max_len = max( + input_ids_vid.shape[1], + input_ids_s2t.shape[1], + input_ids_t2s.shape[1], + input_ids_mmu.shape[1], + ) + + # 3. Pad all tensors to the max_len + input_ids_vid = pad_tensor(input_ids_vid, max_len, pad_token_id) + input_ids_s2t = pad_tensor(input_ids_s2t, max_len, pad_token_id) + input_ids_t2s = pad_tensor(input_ids_t2s, max_len, pad_token_id) + input_ids_mmu = pad_tensor(input_ids_mmu, max_len, pad_token_id) + labels_vid = pad_tensor(labels_vid, max_len, -100) + labels_s2t = pad_tensor(labels_s2t, max_len, -100) + labels_t2s = pad_tensor(labels_t2s, max_len, -100) + labels_mmu = pad_tensor(labels_mmu, max_len, -100) + p_mask_vid = pad_tensor(p_mask_vid, max_len, 1.0) + p_mask_s2t = pad_tensor(p_mask_s2t, max_len, 1.0) + p_mask_t2s = pad_tensor(p_mask_t2s, max_len, 1.0) + p_mask_mmu = pad_tensor(p_mask_mmu, max_len, 1.0) + answer_lengths_vid = pad_answer_lengths(answer_lengths_vid, max_len) + answer_lengths_s2t = pad_answer_lengths(answer_lengths_s2t, max_len) + answer_lengths_t2s = pad_answer_lengths(answer_lengths_t2s, max_len) + answer_lengths_mmu = pad_answer_lengths(answer_lengths_mmu, max_len) + # --------------------------------------------------------------------------------- + + input_ids = torch.cat(( + input_ids_vid, + input_ids_s2t, + input_ids_t2s, + input_ids_mmu + ), dim=0) + labels = torch.cat(( + labels_vid, + labels_s2t, + labels_t2s, + labels_mmu + ), dim=0) + + # w/o texts and images + p_mask_lm = None + # p_mask_mmu = None + # answer_lengths_mmu = None + t2i_masks = None + + if global_step == 0 and epoch == 0: + logger.info("Input ids: {}".format(input_ids)) + logger.info("Input ids shape: {}".format(input_ids.shape)) + logger.info("Labels: {}".format(labels)) + + # with accelerator.accumulate(model): + logits, loss_t2i, loss_lm, loss_mmu, loss_vid, loss_s2t, loss_t2s = accelerator.unwrap_model(model).forward_process( + # logits, loss_t2i, loss_lm, loss_mmu, loss_vid, loss_s2t, loss_t2s = model.forward_process( + input_ids=input_ids, + labels=labels, + batch_size_t2i=batch_size_t2i, + batch_size_lm=batch_size_lm, + batch_size_mmu=batch_size_mmu, + batch_size_v2t=batch_size_v2t, + batch_size_s2t=batch_size_s2t, + batch_size_t2s=batch_size_t2s, + max_seq_length=config.dataset.preprocessing.max_seq_length, + p_mask_lm=p_mask_lm, + p_mask_mmu=p_mask_mmu, + p_mask_vid=p_mask_vid, + p_mask_s2t=p_mask_s2t, + p_mask_t2s=p_mask_t2s, + answer_lengths_mmu=answer_lengths_mmu, + answer_lengths_vid=answer_lengths_vid, + answer_lengths_s2t=answer_lengths_s2t, + answer_lengths_t2s=answer_lengths_t2s, + t2i_masks=t2i_masks, + t2s_vocab_start=speech_vocab_start, + t2s_codebook_size=audio_codebook_size, + t2s_special_token_ids=t2s_special_token_ids, + ) + + # Gather the losses across all processes for logging (use reduce to avoid shape mismatches) + # avg_loss_t2i = accelerator.reduce(loss_t2i, reduction='mean') + # avg_loss_lm = accelerator.reduce(loss_lm, reduction='mean') + # avg_loss_mmu = accelerator.reduce(loss_mmu, reduction='mean') + + avg_loss_vid = accelerator.reduce(loss_vid, reduction='mean') + avg_loss_s2t = accelerator.reduce(loss_s2t, reduction='mean') + avg_loss_t2s = accelerator.reduce(loss_t2s, reduction='mean') + avg_loss_mmu = accelerator.reduce(loss_mmu, reduction='mean') + + # loss = (config.training.t2i_coeff * loss_t2i + + # config.training.lm_coeff * loss_lm + + # config.training.mmu_coeff * loss_mmu + + # config.training.vid_coeff * loss_vid + + # config.training.s2t_coeff * loss_s2t + + # config.training.t2s_coeff * loss_t2s) + + loss = (config.training.v2t_coeff * loss_vid + + config.training.s2t_coeff * loss_s2t + + config.training.t2s_coeff * loss_t2s+ + config.training.mmu_coeff * loss_mmu) + + # HMM~~~~~ + avg_masking_rate = accelerator.reduce(p_mask_t2s.mean(), reduction='mean') + + accelerator.backward(loss) + + if config.training.max_grad_norm is not None and accelerator.sync_gradients: + accelerator.clip_grad_norm_(model.parameters(), config.training.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + + # log gradient norm before zeroing it + if ( + accelerator.sync_gradients + and (global_step + 1) % config.experiment.log_grad_norm_every == 0 + and accelerator.is_main_process + ): + log_grad_norm(model, accelerator, global_step + 1) + + optimizer.zero_grad(set_to_none=True) + + if accelerator.sync_gradients: + batch_time_m.update(time.time() - end) + end = time.time() + + # Log metrics + if (global_step + 1) % config.experiment.log_every == 0: + samples_per_second_per_gpu = ( + config.training.gradient_accumulation_steps * total_batch_size_per_gpu / batch_time_m.val + ) + logs = { + # "step_loss_t2i": avg_loss_t2i.item(), + # "step_loss_mmu": avg_loss_mmu.item(), + # "step_loss_lm": avg_loss_lm.item(), + "step_loss_vid": avg_loss_vid.item(), + "step_loss_s2t": avg_loss_s2t.item(), + "step_loss_t2s": avg_loss_t2s.item(), + "step_loss_mmu": avg_loss_mmu.item(), + "lr": lr_scheduler.get_last_lr()[0], + # "avg_masking_rate": avg_masking_rate.item(), + "samples/sec/gpu": samples_per_second_per_gpu, + "data_time": data_time_m.val, + "batch_time": batch_time_m.val, + } + accelerator.log(logs, step=global_step + 1) + + logger.info( + f"Step: {global_step + 1} " + # f"Loss_t2i: {avg_loss_t2i.item():0.4f} " + # f"Loss_mmu: {avg_loss_mmu.item():0.4f} " + # f"Loss_lm: {avg_loss_lm.item():0.4f} " + f"Loss_vid: {avg_loss_vid.item():0.4f} " + f"Loss_s2t: {avg_loss_s2t.item():0.4f} " + f"Loss_t2s: {avg_loss_t2s.item():0.4f} " + f"Loss_mmu: {avg_loss_mmu.item():0.4f} " + f"Data (t): {data_time_m.val:0.4f}, {samples_per_second_per_gpu:0.2f}/s/gpu " + f"Batch (t): {batch_time_m.val:0.4f} " + f"LR: {lr_scheduler.get_last_lr()[0]:0.6f}" + ) + + # resetting batch / data time meters per log window + batch_time_m.reset() + data_time_m.reset() + + # Save model checkpoint + if (global_step + 1) % config.experiment.save_every == 0: + save_checkpoint(model, config, accelerator, global_step + 1, uni_prompting) + + # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + # ++++++++++++++++++++++ RUN EVALUATION +++++++++++++++++++++++++ + # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + if global_step == 60000 or (global_step + 1) % config.experiment.get("eval_every", 5000) == 0: + run_evaluation( + model=accelerator.unwrap_model(model), + vq_model_image=vq_model_image, + vq_model_audio=vq_model_audio, + uni_prompting=uni_prompting, + config=config, + accelerator=accelerator, + global_step=global_step + 1 + ) + # Evaluation function sets model back to train mode internally + # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + + global_step += 1 + + if global_step >= config.training.max_train_steps: + break + + if global_step >= config.training.max_train_steps: + break + + accelerator.wait_for_everyone() + + save_checkpoint(model, config, accelerator, global_step, uni_prompting) + + if accelerator.is_main_process: + model = accelerator.unwrap_model(model) + model.save_pretrained(config.experiment.output_dir, safe_serialization=True) + + accelerator.end_training() + +@torch.no_grad() +def visualize_predictions(*args, **kwargs): + # This function is not called in the main loop but kept for compatibility + pass + +@torch.no_grad() +def generate_images(*args, **kwargs): + # This function is not called in the main loop but kept for compatibility + pass + +@torch.no_grad() +def understanding_images(*args, **kwargs): + # This function is not called in the main loop but kept for compatibility + pass + +def save_checkpoint(model, config, accelerator, global_step, uni_prompting): + output_dir = config.experiment.output_dir + checkpoints_total_limit = config.experiment.get("checkpoints_total_limit", None) + + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if accelerator.is_main_process and checkpoints_total_limit is not None: + checkpoints = os.listdir(output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= checkpoints_total_limit: + num_to_remove = len(checkpoints) - checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = Path(output_dir) / f"checkpoint-{global_step}" + + # retrieve the model on all processes for deepspeed stage 3 to work then save on one process (we are not using stage 3 yet) + # XXX: could also make this conditional on deepspeed + state_dict = accelerator.get_state_dict(model) + if accelerator.is_main_process: + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.save_pretrained( + save_path / "unwrapped_model", + save_function=accelerator.save, + state_dict=state_dict, + safe_serialization=True + ) + json.dump({"global_step": global_step}, (save_path / "metadata.json").open("w+")) + logger.info(f"Saved state to {save_path}") + + # save tokenizer + uni_prompting.text_tokenizer.save_pretrained(save_path/ "unwrapped_model") + + +def log_grad_norm(model, accelerator, global_step): + for name, param in model.named_parameters(): + if param.grad is not None: + grads = param.grad.detach().data + grad_norm = (grads.norm(p=2) / grads.numel()).item() + accelerator.log({"grad_norm/" + name: grad_norm}, step=global_step) + + +if __name__ == "__main__": + main() diff --git a/MMaDA/training/train_omada_stage1-2.py b/MMaDA/training/train_omada_stage1-2.py new file mode 100644 index 0000000000000000000000000000000000000000..607b9185abe6c7ea0ae56fb34b2b820666bfcca9 --- /dev/null +++ b/MMaDA/training/train_omada_stage1-2.py @@ -0,0 +1,1508 @@ +# Copyright 2025 AIDAS Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +os.environ["TOKENIZERS_PARALLELISM"] = "true" +import json +import pandas +import logging +import math +import shutil +import time +import cv2 +import glob +import random +from tqdm import tqdm +from pathlib import Path +from typing import Union +import csv +import numpy as np +from PIL import Image +from omegaconf import OmegaConf +import wandb +import torch +from torch.optim import AdamW +from lightning.pytorch.utilities import CombinedLoader + +from transformers import AutoTokenizer, AutoConfig +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import DistributedType, set_seed +# +++++ I2I-specific Imports +++++ +from datasets import load_dataset +from torch.utils.data import Dataset, DataLoader +from tqdm.auto import tqdm +# ++++++++++++++++++++++++++++++ + +# +++++ Omni-modal-specific Imports +++++ +from models.modeling_emova_speech_tokenizer import EMOVASpeechTokenizer +from datasets import load_dataset +from torch.utils.data import Dataset, DataLoader, DistributedSampler +from tqdm.auto import tqdm +from training.data import SpeechTextDataset, MixedSpeechTextDataset, load_video_mp4, VideoCaptionDataset, S2T_INSTRUCTION, T2S_INSTRUCTION +# import librosa + +from training.data import Text2ImageDataset +from training.utils import get_config, flatten_omega_conf, image_transform +from training.imagenet_dataset import ImageNetDataset +from parquet import RefinedWebDataset, ChatDataset + +from models import MAGVITv2, get_mask_schedule, OMadaModelLM, OMadaConfig +from training.prompting_utils import UniversalPrompting +from models.lr_schedulers import get_scheduler +from models.logging import set_verbosity_info, set_verbosity_error + +from torch.utils.data import DataLoader, Dataset +from torch.utils.data.distributed import DistributedSampler + +# ++++++++ EVALUATION IMPORTS ++++++++ +import re +import editdistance +import soundfile as sf +from functools import partial +from transformers import pipeline +# ++++++++++++++++++++++++++++++++++++ + +SYSTEM_PROMPT_LEN = 28 + +from training.utils import get_config, flatten_omega_conf, mask_or_random_replace_tokens, AverageMeter + +try: + import apex + + is_apex_available = True +except ImportError: + is_apex_available = False + +logger = get_logger(__name__, log_level="INFO") + +def pad_tensor(tensor, length, value): + pad_size = length - tensor.shape[1] + if pad_size <= 0: + return tensor + # Pad on the right side of the sequence (last dimension) + return torch.nn.functional.pad(tensor, (0, pad_size), "constant", value) + +def pad_answer_lengths(ans: torch.Tensor, length: int) -> torch.Tensor: + b, l = ans.shape + if l >= length: + return ans + pad_block = ans[:, :1].expand(b, length - l) + return torch.cat([ans, pad_block], dim=1) + +def resize_vocab(model, config): + logger.info(f"Resizing token embeddings to {config.model.omada.new_vocab_size}") + model.resize_token_embeddings(config.model.omada.new_vocab_size) + +def get_vq_model_class(model_type): + if model_type == "magvitv2": + return MAGVITv2 + elif model_type == "emova": + return EMOVASpeechTokenizer.from_pretrained( + "Emova-ollm/emova_speech_tokenizer_hf" + ) + else: + raise ValueError(f"model_type {model_type} not supported.") + +def collate_fn_audio(batch): + # In this setup, the tokenizer handles batching of audio paths + return { + 'audio_path': [item['audio_path'] for item in batch], + 'text': [item['text'] for item in batch], + } + +def collate_fn_video_caption(batch): + + batch = [item for item in batch if item is not None] + if len(batch) == 0: + return None + + frame_list = [] + input_ids_list = [] + for item in batch: + frame_tensor = torch.stack(item['video'], dim=0) # (T, C, H, W) + frame_list.append(frame_tensor) + input_ids_list.append(item['caption']) + + frames = torch.stack(frame_list, dim=0) # (B, T, C, H, W) + + return { + "video": frames, # torch tensor (B, T, C, H, W) + "captions": input_ids_list # input_ids (B, seq_len) + } + +def s2t_eval_collate_fn(batch, vq_model_audio, tokenizer, uni_prompting, config): + + audio_tokens_batch = [] + offset = len(uni_prompting.text_tokenizer) + config.model.omada.codebook_size + for item in batch: + path = item['audio_path'] + tokens = vq_model_audio.encode(path) + tokens_with_offset = tokens + offset + audio_tokens_batch.append(tokens_with_offset) + + sptids_dict = uni_prompting.sptids_dict + device = audio_tokens_batch[0].device + batched_input_ids = [] + + for audio_tokens in audio_tokens_batch: + task_tensor = sptids_dict['<|s2t|>'].to(device).unsqueeze(0) + soa_tensor = sptids_dict['<|soa|>'].to(device).unsqueeze(0) + eoa_tensor = sptids_dict['<|eoa|>'].to(device).unsqueeze(0) + audio_block = torch.cat([task_tensor, soa_tensor, audio_tokens, eoa_tensor], dim=1) + + prompt_text = random.choice(S2T_INSTRUCTION) + full_prompt_text = f'<|start_header_id|>user<|end_header_id|>\n{prompt_text}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n' + prompt_tensor = tokenizer(full_prompt_text, return_tensors="pt").input_ids.to(device) + + final_sequence = torch.cat([audio_block, prompt_tensor], dim=1) + batched_input_ids.append(final_sequence.squeeze(0)) + + max_len = max(seq.size(0) for seq in batched_input_ids) + pad_token_id = 126093 + + final_batch_input_ids = torch.full( + (len(batched_input_ids), max_len), + pad_token_id, + dtype=torch.long, + device=device + ) + + for i, seq in enumerate(batched_input_ids): + final_batch_input_ids[i, -len(seq):] = seq + + return { + "input_ids": final_batch_input_ids, + "gt_texts": [item['gt_text'] for item in batch], + "sample_ids": [item['sample_id'] for item in batch] + } + +################################################################################################ +# +++++++++++++++++++++++++++++++++++++ EVALUATION HELPERS +++++++++++++++++++++++++++++++++++++ +################################################################################################ + +def normalize_text(text): + """A simple normalizer for WER calculation.""" + text = text.lower() + text = re.sub(r"[^\w\s']", "", text) + return text + +def calculate_wer(predictions, references): + """Calculates the Word Error Rate (WER) between predicted and ground truth texts.""" + predictions = [normalize_text(p) for p in predictions] + references = [normalize_text(r) for r in references] + + total_errors = 0 + total_words = 0 + for pred, ref in zip(predictions, references): + pred_words = pred.split() + ref_words = ref.split() + total_errors += editdistance.eval(pred_words, ref_words) + total_words += len(ref_words) + + wer = total_errors / total_words if total_words > 0 else 0.0 + return wer, total_errors, total_words + +class S2TEvalDataset(Dataset): + def __init__(self, hf_dataset, root_path): + self.hf_dataset = hf_dataset + self.root_path = root_path + + def __len__(self): + return len(self.hf_dataset) + + def __getitem__(self, idx): + example = self.hf_dataset[idx] + sample_id = example['id'] + speaker_id, chapter_id, _ = sample_id.split('-') + audio_path = os.path.join(self.root_path, speaker_id, chapter_id, f"{sample_id}.flac") + + return { + "audio_path": audio_path, + "gt_text": example["text"], + "sample_id": sample_id + } + +# --- T2S Evaluation Dataset --- +class T2SEvalDataset(Dataset): + def __init__(self, hf_dataset): + self.hf_dataset = hf_dataset + def __len__(self): + return len(self.hf_dataset) + def __getitem__(self, idx): + example = self.hf_dataset[idx] + return {"gt_text": example['text'], "sample_id": example['id']} + + +################################################################################################ +# +++++++++++++++++++++++++++++++++++++ S2T EVALUATION LOGIC +++++++++++++++++++++++++++++++++++++ +################################################################################################ +@torch.no_grad() +def evaluate_s2t(model, vq_model_audio, uni_prompting, config, accelerator, global_step): + logger.info("***** Running S2T Evaluation (WER on Librispeech test-clean) *****") + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.eval() + + # 1. Load Dataset + try: + s2t_eval_dataset_raw = load_dataset("librispeech_asr", "clean", split="test", streaming=False).select(range(128)) + s2t_eval_dataset = S2TEvalDataset(s2t_eval_dataset_raw, root_path = "/home/work/AIDAS/data/audio/LibriSpeech/test-clean") + except Exception as e: + logger.error(f"Failed to load S2T evaluation dataset: {e}") + return + + collate_with_args = partial( + s2t_eval_collate_fn, + vq_model_audio=vq_model_audio, + tokenizer=uni_prompting.text_tokenizer, + uni_prompting=uni_prompting, + config=config + ) + + s2t_eval_dataloader = DataLoader(s2t_eval_dataset, batch_size=config.training.batch_size_s2t, shuffle=False, collate_fn=collate_with_args) + s2t_eval_dataloader = accelerator.prepare(s2t_eval_dataloader) + + local_results = [] + + for batch in tqdm(s2t_eval_dataloader, desc="S2T Evaluation", disable=not accelerator.is_main_process): + input_ids = batch["input_ids"] + gt_texts = batch["gt_texts"] + sample_ids = batch["sample_ids"] + + output_ids = unwrapped_model.mmu_generate(input_ids, max_new_tokens=256, steps=128, block_length=64, remasking='low_confidence') + + decoded_texts = uni_prompting.text_tokenizer.batch_decode(output_ids[:, input_ids.shape[1]:], skip_special_tokens=True) + + for i in range(len(decoded_texts)): + local_results.append({"gt_text": gt_texts[i], "decoded_text": decoded_texts[i]}) + + # 3. Gather and Log Results + all_results = accelerator.gather_for_metrics(local_results) + + if accelerator.is_main_process: + if not all_results: + logger.warning("S2T evaluation produced no results.") + return + gt_list = [res["gt_text"] for res in all_results] + pred_list = [res["decoded_text"] for res in all_results] + + wer, errors, words = calculate_wer(pred_list, gt_list) + logger.info(f"S2T Final WER (Librispeech test-clean): {wer:.4f} | Word Errors: {errors} | Total Words: {words}") + + accelerator.log({ + "eval/s2t_wer": wer, + "eval/s2t_word_errors": errors, + "eval/s2t_total_words": words + }, step=global_step) + +################################################################################################ +# +++++++++++++++++++++++++++++++++++++ T2S EVALUATION LOGIC +++++++++++++++++++++++++++++++++++++ +################################################################################################ +@torch.no_grad() +def evaluate_t2s(model, vq_model_audio, uni_prompting, config, accelerator, global_step): + logger.info("***** Running T2S Evaluation (WER via Whisper on Librispeech) *****") + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.eval() + + # 1. Load Dataset & Whisper Model (only on main process for model download) + if accelerator.is_main_process: + try: + # Using a smaller subset for faster evaluation during training + t2s_eval_dataset_raw = load_dataset("librispeech_asr", "clean", split="test").select(range(128)) + whisper_pipe = pipeline("automatic-speech-recognition", model="openai/whisper-large-v3", device=accelerator.device) + os.makedirs(f"{config.experiment.output_dir}/eval_audio", exist_ok=True) + except Exception as e: + logger.error(f"Failed to load T2S dataset or Whisper model: {e}") + whisper_pipe = None + + accelerator.wait_for_everyone() + # Re-initialize on other processes if main process succeeded + if not accelerator.is_main_process: + try: + t2s_eval_dataset_raw = load_dataset("librispeech_asr", "clean", split="test").select(range(128)) + whisper_pipe = pipeline("automatic-speech-recognition", model="openai/whisper-large-v3", device=accelerator.device) + except Exception as e: + whisper_pipe = None + + if whisper_pipe is None: + logger.warning("Skipping T2S evaluation as Whisper or dataset failed to load.") + return + + output_dir_per_step = os.path.join("/home/work/AIDAS", config.experiment.output_dir, "eval_audio", f"step_{global_step}") + os.makedirs(output_dir_per_step, exist_ok=True) + + t2s_eval_dataset = T2SEvalDataset(t2s_eval_dataset_raw) + t2s_dataloader = DataLoader(t2s_eval_dataset, batch_size=config.training.batch_size_t2s) + t2s_dataloader = accelerator.prepare(t2s_dataloader) + + local_results = [] + mask_token_id = unwrapped_model.config.mask_token_id + mask_schedule = get_mask_schedule(config.training.get("mask_schedule", "cosine")) + offset = len(uni_prompting.text_tokenizer) + config.model.omada.codebook_size + + # 2. Evaluation Loop + for batch in tqdm(t2s_dataloader, desc="T2S Evaluation", disable=not accelerator.is_main_process): + gt_texts = batch["gt_text"] + sample_ids = batch["sample_id"] + + prompts = [f"{text}\n{random.choice(T2S_INSTRUCTION)}" for text in gt_texts] + batch_size = len(prompts) + + # We need a reasonable length for generated audio tokens + speech_token_length = config.dataset.preprocessing.max_aud_length - 1 # -1 for soa token + audio_tokens = torch.ones((batch_size, speech_token_length), dtype=torch.long, device=accelerator.device) * mask_token_id + input_ids, attention_mask = uni_prompting((prompts, audio_tokens), 't2s_gen') + + if config.training.guidance_scale > 0: + uncond_input_ids, uncond_attention_mask = uni_prompting(([''] * batch_size, audio_tokens), 't2s_gen') + else: + uncond_input_ids, uncond_attention_mask = None, None + + output_ids = unwrapped_model.t2s_generate( + input_ids=input_ids, + uncond_input_ids=uncond_input_ids, + attention_mask=attention_mask, + uncond_attention_mask=uncond_attention_mask, + guidance_scale=config.training.guidance_scale, + temperature=1.0, + timesteps=24, + noise_schedule=mask_schedule, + noise_type="mask", + seq_len=100, + uni_prompting=uni_prompting, + config=config, + ) + + # Decode and run Whisper + for i in range(batch_size): + gt = gt_texts[i].rsplit("\n", 1)[-1].strip() + + gen_speech_tokens = output_ids[i] + + # Remove padding/eos if necessary, clamp to valid range + # gen_speech_tokens = torch.clamp(gen_speech_tokens, min=0, max= 4096 - 1) + id_list = gen_speech_tokens.cpu().tolist() + + if not id_list: + logger.warning(f"Generated token list is empty for sample {sample_ids[i]}. Skipping.") + continue + + speech_unit_str = " ".join(map(str, id_list)) + speech_unit_for_decode = "".join([f"<|speech_{unit}|>" for unit in speech_unit_str.split(" ")]) + + filename = f"process_{accelerator.process_index}_{sample_ids[i]}.wav" + output_wav_path = os.path.join(output_dir_per_step, filename) + condition = 'gender-female_emotion-neutral_speed-normal_pitch-normal' + + audio_array = vq_model_audio.decode(speech_unit_for_decode, condition=condition, output_wav_file=output_wav_path) + + whisper_result = whisper_pipe(output_wav_path, generate_kwargs={"language": "english"}) + whisper_text = whisper_result.get("text", "") + + local_results.append({ + "sample_id": sample_ids[i], "gt_text": gt, "whisper_text": whisper_text, "audio_path": output_wav_path + }) + + # 3. Gather and Log Results + all_results = accelerator.gather_for_metrics(local_results) + + if accelerator.is_main_process: + gt_list = [res["gt_text"] for res in all_results] + pred_list = [res["whisper_text"] for res in all_results] + + wer, errors, words = calculate_wer(pred_list, gt_list) + logger.info(f"T2S Final WER (via Whisper): {wer:.4f} | Word Errors: {errors} | Total Words: {words}") + + accelerator.log({ + "eval/t2s_wer": wer, + "eval/t2s_word_errors": errors, + "eval/t2s_total_words": words + }, step=global_step) + + # Log some audio samples and transcriptions to W&B + results_table = wandb.Table(columns=["ID", "Ground Truth", "Whisper Transcription", "Generated Audio"]) + for res in all_results[:8]: # Log first 8 samples + audio = wandb.Audio(res["audio_path"], caption=res["whisper_text"]) + results_table.add_data(res["sample_id"], res["gt_text"], res["whisper_text"], audio) + + accelerator.log({"eval/t2s_samples": results_table}, step=global_step) + +@torch.no_grad() +def evaluate_t2s_fixed(model, vq_model_audio, uni_prompting, config, accelerator, global_step): + """ + Text-to-Speech (fixed-length) evaluation: + - Input prompt contains SOA + [MASK]*L + EOA (EOA is injected, not predicted) + - The model only fills VQ codes for exactly L positions (no EOA/EOS prediction) + - Generated audio is transcribed by Whisper; we report WER + """ + logger.info("***** Running T2S (fixed-length) Evaluation *****") + unwrapped = accelerator.unwrap_model(model) + unwrapped.eval() + + # Load eval dataset and Whisper on the main process; re-init on others for distributed runs + if accelerator.is_main_process: + try: + ds_raw = load_dataset("librispeech_asr", "clean", split="test").select(range(128)) + whisper_pipe = pipeline( + "automatic-speech-recognition", + model="openai/whisper-large-v3", + device=accelerator.device + ) + os.makedirs(f"{config.experiment.output_dir}/eval_audio", exist_ok=True) + except Exception as e: + logger.error(f"Failed to load dataset or Whisper model: {e}") + whisper_pipe = None + + accelerator.wait_for_everyone() + + # Initialize on non-main processes if needed + if not accelerator.is_main_process: + try: + ds_raw = load_dataset("librispeech_asr", "clean", split="test").select(range(128)) + whisper_pipe = pipeline( + "automatic-speech-recognition", + model="openai/whisper-large-v3", + device=accelerator.device + ) + except Exception: + whisper_pipe = None + + if whisper_pipe is None: + logger.warning("Skipping T2S fixed evaluation due to missing Whisper/dataset.") + return + + # Directory for saving generated audio files of this evaluation step + out_dir = os.path.join( + "/home/work/AIDAS", config.experiment.output_dir, "eval_audio", f"step_{global_step}_fixed" + ) + os.makedirs(out_dir, exist_ok=True) + + eval_ds = T2SEvalDataset(ds_raw) + loader = DataLoader(eval_ds, batch_size=config.training.batch_size_t2s) + loader = accelerator.prepare(loader) + + local_results = [] + mask_token_id = unwrapped.config.mask_token_id + mask_schedule = get_mask_schedule(config.training.get("mask_schedule", "cosine")) + + for batch in tqdm(loader, desc="T2S Fixed Evaluation", disable=not accelerator.is_main_process): + gt_texts = batch["gt_text"] + sample_ids = batch["sample_id"] + + prompts = [f"{text}\n{random.choice(T2S_INSTRUCTION)}" for text in gt_texts] + batch_size = len(prompts) + + # We need a reasonable length for generated audio tokens + speech_token_length = config.dataset.preprocessing.max_aud_length - 2 # -1 for soa token + audio_tokens = torch.ones((batch_size, speech_token_length), dtype=torch.long, device=accelerator.device) * mask_token_id + input_ids, attention_mask = uni_prompting((prompts, audio_tokens), 't2s_fixed_gen') + + if config.training.guidance_scale > 0: + uncond_input_ids, uncond_attention_mask = uni_prompting(([''] * batch_size, audio_tokens), 't2s_fixed_gen') + else: + uncond_input_ids, uncond_attention_mask = None, None + + # Core generation call: + # - predict_eoa=False prevents EOA/EOS prediction; only VQ codes are generated + outputs = unwrapped.t2s_fixed_generate( + input_ids=input_ids, + uncond_input_ids=uncond_input_ids, + attention_mask=attention_mask, + uncond_attention_mask=uncond_attention_mask, + guidance_scale=1.5, + temperature=1.0, + timesteps=64, + noise_schedule=mask_schedule, + noise_type="mask", + seq_len=150, + uni_prompting=uni_prompting, + config=config, + ) + + # Decode generated VQ codes → waveform via the speech tokenizer, then ASR with Whisper + for i in range(batch_size): + gt = gt_texts[i].rsplit("\n", 1)[-1].strip() + gen_rel = outputs[i] # relative VQ ids in [0..4095] + id_list = gen_rel.tolist() + + if not id_list: + logger.warning(f"[fixed] Empty tokens for {sample_ids[i]}; skipping.") + continue + + # Convert to the speech-unit string format expected by the decoder + unit_str = " ".join(map(str, id_list)) + speech_unit_for_decode = "".join([f"<|speech_{u}|>" for u in unit_str.split(" ")]) + + # Synthesize audio and run Whisper + fname = f"process_{accelerator.process_index}_{sample_ids[i]}_fixed.wav" + wav_path = os.path.join(out_dir, fname) + condition = 'gender-female_emotion-neutral_speed-normal_pitch-normal' + + _ = vq_model_audio.decode( + speech_unit_for_decode, + condition=condition, + output_wav_file=wav_path + ) + asr = whisper_pipe(wav_path, generate_kwargs={"language": "english"}) + whisper_text = asr.get("text", "") + + local_results.append({ + "sample_id": sample_ids[i], + "gt_text": gt, + "whisper_text": whisper_text, + "audio_path": wav_path + }) + + # Gather results across processes and compute WER on the main process + all_res = accelerator.gather_for_metrics(local_results) + if accelerator.is_main_process and all_res: + gt_list = [r["gt_text"] for r in all_res] + pred_list = [r["whisper_text"] for r in all_res] + wer, errors, words = calculate_wer(pred_list, gt_list) + logger.info(f"T2S Fixed WER: {wer:.4f} | Errors: {errors} | Words: {words}") + + accelerator.log({ + "eval/t2s_fixed_wer": wer, + "eval/t2s_fixed_errors": errors, + "eval/t2s_fixed_words": words + }, step=global_step) + + # Log a small subset of samples to Weights & Biases + table = wandb.Table(columns=["ID", "GT", "ASR", "Audio"]) + for r in all_res[:8]: + table.add_data( + r["sample_id"], + r["gt_text"], + r["whisper_text"], + wandb.Audio(r["audio_path"], caption=r["whisper_text"]) + ) + accelerator.log({"eval/t2s_fixed_samples": table}, step=global_step) + +################################################################################################ +# +++++++++++++++++++++++++++++++++++++ V2T EVALUATION LOGIC +++++++++++++++++++++++++++++++++++++ +################################################################################################ +@torch.no_grad() +def evaluate_v2t(model, vq_model_image, uni_prompting, config, accelerator, global_step): + # This is a qualitative evaluation, so it only runs on the main process. + if not accelerator.is_main_process: + return + + logger.info("***** Running V2T Qualitative Evaluation *****") + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.eval() + + video_root = "/home/work/AIDAS/video/demo" + if not video_root or not os.path.exists(video_root): + logger.warning(f"V2T eval root '{video_root}' not found. Skipping V2T evaluation.") + return + + file_list = [f for f in os.listdir(video_root) if f.lower().endswith('.mp4')] + if not file_list: + logger.warning(f"No .mp4 files found in '{video_root}'. Skipping V2T evaluation.") + return + + question = "Please provide a detailed description of the video." + results_table = wandb.Table(columns=["Video ID", "Question", "Generated Caption"]) + + for file_name in tqdm(file_list[:], desc="V2T Evaluation", disable=not accelerator.is_main_process): + video_path = os.path.join(video_root, file_name) + + # 1. Load and process video + cap = cv2.VideoCapture(video_path) + total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + indices = np.linspace(0, total_frames - 1, 8, dtype=int) + frames = [] + for i in range(total_frames): + ret, frame = cap.read() + if i in indices: + if not ret: continue + frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + pil_img = Image.fromarray(frame) + frames.append(image_transform(pil_img, resolution=config.dataset.preprocessing.resolution)) + cap.release() + + if len(frames) < 8: continue + + video_tensor = torch.stack(frames).to(accelerator.device) + video_tokens = vq_model_image.get_code(video_tensor) + len(uni_prompting.text_tokenizer) + video_tokens = video_tokens.view(1, -1) # Flatten tokens + + sptids = uni_prompting.sptids_dict + device = unwrapped_model.device + + prompt_text = f'<|start_header_id|>user<|end_header_id|>\n{question}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n' + prompt_tensor = uni_prompting.text_tokenizer(prompt_text, return_tensors="pt").input_ids.to(device) + + input_ids = torch.cat([ + sptids['<|v2t|>'].to(device).unsqueeze(0), + sptids['<|soi|>'].to(device).unsqueeze(0), + video_tokens, + sptids['<|eoi|>'].to(device).unsqueeze(0), + sptids['<|sot|>'].to(device).unsqueeze(0), + prompt_tensor + ], dim=1).long() + + output_ids = unwrapped_model.mmu_generate(input_ids, max_new_tokens=256, steps=256, block_length=128) + text = uni_prompting.text_tokenizer.batch_decode(output_ids[:, input_ids.shape[1]:], skip_special_tokens=True)[0] + print(text) + # 3. Log result + results_table.add_data(file_name, question, text) + + # except Exception as e: + # logger.error(f"Error processing video {file_name}: {e}") + + accelerator.log({"eval/v2t_qualitative_samples": results_table}, step=global_step) + + +################################################################################################ +# +++++++++++++++++++++++++++++++++++++ MAIN EVALUATION ORCHESTRATOR +++++++++++++++++++++++++++++ +################################################################################################ + +def run_evaluation(model, vq_model_image, vq_model_audio, uni_prompting, config, accelerator, global_step): + """ + Orchestrates the S2T, T2S, and V2T evaluations. + """ + logger.info(f"--- Starting evaluation at step {global_step} ---") + model.eval() + + # --- Run S2T Evaluation --- + evaluate_s2t(model, vq_model_audio, uni_prompting, config, accelerator, global_step) + + # --- Run T2S Evaluation --- + evaluate_t2s(model, vq_model_audio, uni_prompting, config, accelerator, global_step) + evaluate_t2s_fixed(model, vq_model_audio, uni_prompting, config, accelerator, global_step) + + # --- Run V2T Evaluation --- + evaluate_v2t(model, vq_model_image, uni_prompting, config, accelerator, global_step) + + accelerator.wait_for_everyone() + logger.info(f"--- Finished evaluation at step {global_step}. Returning to training. ---") + model.train() + + +def main(): + ######################### + # SETUP Accelerator # + ######################### + config = get_config() + + # Enable TF32 on Ampere GPUs + if config.training.enable_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.benchmark = True + torch.backends.cudnn.deterministic = False + + config.experiment.logging_dir = str(Path(config.experiment.output_dir) / "logs") + accelerator = Accelerator( + gradient_accumulation_steps=config.training.gradient_accumulation_steps, + mixed_precision=config.training.mixed_precision, + log_with="wandb", + project_dir=config.experiment.logging_dir, + split_batches=True, + ) + + total_batch_size_per_gpu = (config.training.batch_size_t2i + + config.training.batch_size_lm + + config.training.batch_size_mmu + + config.training.batch_size_v2t + + config.training.batch_size_s2t + + config.training.batch_size_t2s) + total_batch_size = ( + (config.training.batch_size_t2i + + config.training.batch_size_lm + + config.training.batch_size_mmu + + config.training.batch_size_v2t + + config.training.batch_size_s2t + + config.training.batch_size_t2s) * accelerator.num_processes * config.training.gradient_accumulation_steps + ) + + if accelerator.distributed_type == DistributedType.DEEPSPEED: + accelerator.state.deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = ( + total_batch_size_per_gpu + ) + + ##################################### + # SETUP LOGGING, SEED and CONFIG # + ##################################### + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + set_verbosity_info() + else: + set_verbosity_error() + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + resume_wandb_run = config.wandb.resume + run_id = config.wandb.get("run_id", None) + if run_id is None: + resume_wandb_run = False + run_id = wandb.util.generate_id() + config.wandb.run_id = run_id + + wandb_init_kwargs = dict( + name=config.experiment.name, + id=run_id, + resume=resume_wandb_run, + entity=config.wandb.get("entity", None), + config_exclude_keys=[], + dir = config.experiment.logging_dir, + ) + wandb_config = {k: v for k, v in flatten_omega_conf(config, resolve=True)} + wandb_config.pop("experiment.resume_from_checkpoint") + + accelerator.init_trackers( + config.experiment.project, + config=wandb_config, + init_kwargs={"wandb": wandb_init_kwargs}, + ) + + if accelerator.is_main_process: + os.makedirs(config.experiment.output_dir, exist_ok=True) + config_path = Path(config.experiment.output_dir) / "config.yaml" + logging.info(f"Saving config to {config_path}") + OmegaConf.save(config, config_path) + + # If passed along, set the training seed now. + if config.training.seed is not None: + set_seed(config.training.seed) + + ######################### + # MODELS and OPTIMIZER # + ######################### + logger.info("Loading models and optimizer") + + tokenizer = AutoTokenizer.from_pretrained(config.model.omada.tokenizer_path, padding_side="left") + + uni_prompting = UniversalPrompting(tokenizer, max_text_len=config.dataset.preprocessing.max_seq_length, max_audio_len=config.dataset.preprocessing.max_aud_length, + special_tokens=( + "<|soi|>", "<|eoi|>", "<|sov|>", "<|eov|>", "<|t2i|>", + "<|mmu|>", "<|t2v|>", "<|v2v|>", "<|lvg|>", + # Omada Special Tokens + "<|v2t|>", "<|s2t|>", "<|t2s|>", "<|soa|>", "<|eoa|>", + ), + ignore_id=-100, cond_dropout_prob=config.training.cond_dropout_prob, use_reserved_token=True) + + print('special tokens : \n', uni_prompting.sptids_dict) + + speech_vocab_start = len(uni_prompting.text_tokenizer) + int(config.model.omada.codebook_size) + audio_codebook_size = max(int(config.model.omada.new_vocab_size) - speech_vocab_start, 0) + t2s_special_token_ids = { + "eoa": int(uni_prompting.sptids_dict['<|eoa|>'][0].item()), + "eos": int(uni_prompting.text_tokenizer.eos_token_id), + } + + # VQ model for processing image into discrete tokens + vq_model_image = get_vq_model_class(config.model.vq_model_image.type) + if config.model.vq_model_image.get("pretrained_model_path", None): + vq_model_image = vq_model_image().to(accelerator.device) + state_dict = torch.load(config.model.vq_model_image.pretrained_model_path)['model'] + vq_model_image.load_state_dict(state_dict) + else: + vq_model_image = vq_model_image.from_pretrained(config.model.vq_model_image.vq_model_name).to(accelerator.device) + + vq_model_audio = get_vq_model_class(config.model.vq_model_audio.type) + vq_model_audio = vq_model_audio.from_pretrained(config.model.vq_model_audio.vq_model_name).to(accelerator.device) + + vq_model_image.eval() + vq_model_image.requires_grad_(False) + + vq_model_audio.eval() + vq_model_audio.requires_grad_(False) + + model = OMadaModelLM.from_pretrained(config.model.omada.pretrained_model_path, torch_dtype=torch.bfloat16).to(accelerator.device) + + # Resize Vocab size for Audio Modality + unwrapped_model = accelerator.unwrap_model(model) + original_vocab_size = unwrapped_model.get_input_embeddings().weight.shape[0] + logger.info("="*50) + logger.info(f"Calling resize_vocab...") + logger.info(f"Vocab size BEFORE resizing: {original_vocab_size}") + + resize_vocab(unwrapped_model, config) + + resized_vocab_size = unwrapped_model.get_input_embeddings().weight.shape[0] + logger.info(f"Vocab size AFTER resizing: {resized_vocab_size}") + logger.info(f"Config 'new_vocab_size': {config.model.omada.new_vocab_size}") + + if resized_vocab_size == config.model.omada.new_vocab_size: + logger.info("āœ… Vocab resize successful!") + else: + logger.info("āŒ Vocab resize FAILED or did not match config!") + logger.info("="*50) + mask_id = model.config.mask_token_id + + ################################## + # Optimizer and LR scheduler # + ################################# + optimizer_config = config.optimizer.params + + # no decay on bias and layernorm and embedding + no_decay = ["bias", "layer_norm.weight", "mlm_ln.weight", "embeddings.weight"] + optimizer_grouped_parameters = [ + { + "params": [p for n, p in model.named_parameters() if + p.requires_grad and not any(nd in n for nd in no_decay)], + "weight_decay": optimizer_config.weight_decay, + }, + { + "params": [p for n, p in model.named_parameters() if + p.requires_grad and any(nd in n for nd in no_decay)], + "weight_decay": 0.0, + }, + ] + + optimizer_type = config.optimizer.name + if optimizer_type == "adamw": + optimizer = AdamW( + optimizer_grouped_parameters, + lr=optimizer_config.learning_rate, + betas=(optimizer_config.beta1, optimizer_config.beta2), + weight_decay=optimizer_config.weight_decay, + eps=optimizer_config.epsilon, + ) + else: + raise ValueError(f"Optimizer {optimizer_type} not supported") + + # Create mask scheduler + if config.get("mask_schedule", None) is not None: + schedule = config.mask_schedule.schedule + args = config.mask_schedule.get("params", {}) + mask_schedule = get_mask_schedule(schedule, **args) + else: + mask_schedule = get_mask_schedule(config.training.get("mask_schedule", "cosine")) + + ################################## + # DATALOADER # + ################################# + logger.info("Creating dataloaders and lr_scheduler") + + total_batch_size = ( + (config.training.batch_size_t2s + config.training.batch_size_s2t +config.training.batch_size_v2t) * accelerator.num_processes * config.training.gradient_accumulation_steps + ) + preproc_config = config.dataset.preprocessing + dataset_config = config.dataset.params + + # Video Dataset + video_captioning_dataset = VideoCaptionDataset( + transform=image_transform, + tokenizer=uni_prompting.text_tokenizer, + max_seq_length=preproc_config.max_seq_length, + resolution=preproc_config.resolution, + sample_method="uniform", + num_frames=8, + ) + + sampler_v2t = DistributedSampler( + video_captioning_dataset, + shuffle=True, # Should be true for training + drop_last=True + ) + + train_dataloader_v2t = DataLoader( + video_captioning_dataset, + batch_size=config.training.batch_size_v2t, + num_workers=dataset_config.num_workers, + collate_fn=collate_fn_video_caption, + sampler = sampler_v2t + ) + + # Speech Dataset + dataset_sm = MixedSpeechTextDataset(config.dataset.params.audio_data) + + logger.info(f"Dataset Prepared.") + + sampler_sm = DistributedSampler(dataset_sm, num_replicas=accelerator.num_processes, rank=accelerator.process_index, shuffle=True) if accelerator.num_processes > 1 else None + + train_dataloader_s2t = DataLoader(dataset_sm, batch_size=config.training.batch_size_s2t, shuffle=False, sampler=sampler_sm, collate_fn=collate_fn_audio, num_workers=config.dataset.params.num_workers) + train_dataloader_t2s = DataLoader(dataset_sm, batch_size=config.training.batch_size_t2s, shuffle=False, sampler=sampler_sm, collate_fn=collate_fn_audio, num_workers=config.dataset.params.num_workers) + + # Combine these dataloaders into a single iterable model + iterables = { + "v2t_flow": train_dataloader_v2t, + "t2s_flow": train_dataloader_t2s, + "s2t_flow": train_dataloader_s2t, + } + + combined_dataloader = CombinedLoader(iterables, mode=config.dataset.combined_loader_mode) + + # s2t + total_batch_size_s2t = config.training.batch_size_s2t * accelerator.num_processes * config.training.gradient_accumulation_steps + num_update_steps_per_epoch_s2t = math.ceil(len(dataset_sm) / total_batch_size_s2t) + + # t2s + total_batch_size_t2s = config.training.batch_size_t2s * accelerator.num_processes * config.training.gradient_accumulation_steps + num_update_steps_per_epoch_t2s = math.ceil(len(dataset_sm) / total_batch_size_t2s) + + # v2t + total_batch_size_v2t = (config.training.batch_size_v2t * accelerator.num_processes * config.training.gradient_accumulation_steps) + num_update_steps_per_epoch_v2t = math.ceil(len(video_captioning_dataset) / total_batch_size_v2t) + + + # Calculate num_train_epochs + num_update_steps_per_epoch = max(num_update_steps_per_epoch_s2t, num_update_steps_per_epoch_t2s, num_update_steps_per_epoch_v2t) + num_train_epochs = math.ceil(config.training.max_train_steps / num_update_steps_per_epoch) if num_update_steps_per_epoch > 0 else 1 + + logger.info(f"len of speech: {len(dataset_sm)}") + logger.info(f"len of video: {len(video_captioning_dataset)}") + logger.info(f"Train stpes: {config.training.max_train_steps}") + logger.info(f"Num train epochs: {num_train_epochs}") + + ################################## + # MODEL RESUME # + ################################# + global_step = 0 + first_epoch = 0 + start_step = 0 + + if config.experiment.resume_from_checkpoint: + dirs = os.listdir(config.experiment.output_dir) + logger.info(f"dirs: {dirs}") + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + logger.info(f"path: {path}") + if path is not None: + path = os.path.join(config.experiment.output_dir, path) + logger.info(f"Resuming from checkpoint: {path}") + global_step = start_step = int(os.path.basename(path).split("-")[1]) + first_epoch = global_step // num_update_steps_per_epoch + if os.path.exists(f'{path}/unwrapped_model/pytorch_model.bin'): + state_dict = torch.load(f'{path}/unwrapped_model/pytorch_model.bin', map_location="cpu") + model.load_state_dict(state_dict, strict=True) + del state_dict + elif os.path.exists(f'{path}/unwrapped_model/pytorch_model.bin.index.json'): + from safetensors.torch import load_file + from transformers.modeling_utils import load_sharded_checkpoint + load_sharded_checkpoint(model, f'{path}/unwrapped_model/') + # if safetensors sharded checkpoint exists + elif os.path.exists(f'{path}/unwrapped_model/model.safetensors.index.json'): + from transformers.modeling_utils import load_sharded_checkpoint + load_sharded_checkpoint( + model, + f'{path}/unwrapped_model/', + ) + else: + raise FileNotFoundError(f"Checkpoint {path}/unwrapped_model/pytorch_model.bin or safetensors not found") + else: + logger.info("Not resuming from checkpoint") + + ################################## + # Prepare accelerator # + ################################# + logger.info("Preparing model, optimizer and dataloaders") + + lr_scheduler = get_scheduler( + config.lr_scheduler.scheduler, + optimizer=optimizer, + num_training_steps=config.training.max_train_steps, + num_warmup_steps=config.lr_scheduler.params.warmup_steps, + min_lr_scale=config.lr_scheduler.params.min_lr_scale + ) + + # model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler) + model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler) + + lr_scheduler = get_scheduler( + config.lr_scheduler.scheduler, + optimizer=optimizer, + num_training_steps=config.training.max_train_steps, + num_warmup_steps=config.lr_scheduler.params.warmup_steps, + min_lr_scale=config.lr_scheduler.params.min_lr_scale + ) + + vq_model_image.to(device=accelerator.device) + vq_model_audio.to(device=accelerator.device) + + mask_dtype = model.get_input_embeddings().weight.dtype + + ################################## + # Training # + ################################# + logger.info("***** Running training *****") + logger.info(f" Num training steps = {config.training.max_train_steps}") + logger.info(f" Instantaneous batch size per device = {total_batch_size_per_gpu}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {config.training.gradient_accumulation_steps}") + + + @torch.no_grad() + def prepare_inputs_and_labels( + pixel_values_or_image_ids: Union[torch.FloatTensor, torch.LongTensor], + texts: Union[str, str], + min_masking_rate: float = 0.0, + is_train: bool = True, + seed: int = None + ): + + image_tokens = vq_model_image.get_code(pixel_values_or_image_ids) + image_tokens = image_tokens + len(uni_prompting.text_tokenizer) + # create MLM mask and labels + input_ids, labels, loss_weight, mask_prob = mask_or_random_replace_tokens( + image_tokens, + mask_id, + config, + mask_schedule=mask_schedule, + is_train=is_train, + ) + input_ids, masks, labels = uni_prompting((texts, input_ids, labels), 't2i') + return input_ids, labels, mask_prob, image_tokens, masks + + @torch.no_grad() + def prepare_inputs_and_labels_for_text( + texts: Union[str, str], max_seq_len, eps=1e-3 + ): + # create MLM mask and labels + + input_ids_lm, prompt_mask, labels_lm = uni_prompting((texts_lm, max_seq_len), 'lm') + b, l = input_ids_lm.shape + t = torch.rand(b, device=input_ids_lm.device) + p_mask = (1 - eps) * t + eps + p_mask = p_mask[:, None].repeat(1, l) + + masked_indices = torch.rand((b, l), device=input_ids_lm.device) < p_mask + # 126336 is used for [MASK] token + noisy_batch = torch.where(masked_indices, mask_id, input_ids_lm) + masked_indices = noisy_batch == mask_id + + return noisy_batch, labels_lm, p_mask + + # Video also uses this. + @torch.no_grad() + def prepare_inputs_and_labels_for_mmu( + input_ids_mmu, prompt_masks, labels_mmu, eps=1e-3 + ): + b, l = input_ids_mmu.shape + t = torch.rand(b, device=input_ids_mmu.device) + p_mask = (1 - eps) * t + eps + p_mask = p_mask[:, None].repeat(1, l) + + masked_indices = torch.rand((b, l), device=input_ids_mmu.device) < p_mask + # 126336 is used for [MASK] token + noisy_batch = torch.where(masked_indices, mask_id, input_ids_mmu) + masked_indices = noisy_batch == mask_id + noisy_batch[prompt_masks.bool()] = input_ids_mmu[prompt_masks.bool()] + masked_indices = noisy_batch == mask_id + + prompt_masks = prompt_masks.to(torch.int64) + answer_lengths = torch.sum((1 - prompt_masks), dim=-1, keepdim=True) + answer_lengths = answer_lengths.repeat(1, noisy_batch.shape[1]) + + return noisy_batch, labels_mmu, p_mask, answer_lengths + + @torch.no_grad() + def prepare_inputs_and_labels_for_t2s( + input_ids_t2s, prompt_masks, labels_t2s, mask_id=126336, eps=1e-3 + ): + b, l = input_ids_t2s.shape + t = torch.rand(b, device=input_ids_t2s.device) + p_mask = (1 - eps) * t + eps + p_mask = p_mask[:, None].repeat(1, l) + + masked_indices = torch.rand((b, l), device=input_ids_t2s.device) < p_mask + noisy_batch = torch.where(masked_indices, mask_id, input_ids_t2s) + masked_indices = noisy_batch == mask_id + + noisy_batch[prompt_masks.bool()] = input_ids_t2s[prompt_masks.bool()] + masked_indices = noisy_batch == mask_id + + prompt_masks = prompt_masks.to(torch.int64) + answer_lengths = torch.sum((1 - prompt_masks), dim=-1, keepdim=True) + answer_lengths = answer_lengths.repeat(1, noisy_batch.shape[1]) + + return noisy_batch, labels_t2s, p_mask, answer_lengths + + + @torch.no_grad() + def prepare_inputs_and_labels_for_s2t( + input_ids_mmu, prompt_masks, labels_mmu, eps=1e-3 + ): + b, l = input_ids_mmu.shape + t = torch.rand(b, device=input_ids_mmu.device) + p_mask = (1 - eps) * t + eps + p_mask = p_mask[:, None].repeat(1, l) + + masked_indices = torch.rand((b, l), device=input_ids_mmu.device) < p_mask + # 126336 is used for [MASK] token + noisy_batch = torch.where(masked_indices, mask_id, input_ids_mmu) + masked_indices = noisy_batch == mask_id + noisy_batch[prompt_masks.bool()] = input_ids_mmu[prompt_masks.bool()] + masked_indices = noisy_batch == mask_id + + prompt_masks = prompt_masks.to(torch.int64) + answer_lengths = torch.sum((1 - prompt_masks), dim=-1, keepdim=True) + answer_lengths = answer_lengths.repeat(1, noisy_batch.shape[1]) + + return noisy_batch, labels_mmu, p_mask, answer_lengths + + batch_time_m = AverageMeter() + data_time_m = AverageMeter() + end = time.time() + + for epoch in tqdm(range(first_epoch, num_train_epochs), desc="Epochs", disable=not accelerator.is_main_process, position=0): + model.train() + for batch, batch_idx, dataloader_idx in combined_dataloader: + batch_size_t2i = 0 + batch_size_lm = 0 + batch_size_mmu = 0 + + if batch is None or batch.get("v2t_flow") is None: + logger.warning(f"Skipping step {global_step} (batch is None or v2t_flow missing)") + continue + + batch_size_v2t = len(batch["v2t_flow"]["video"]) + batch_size_t2s = len(batch["t2s_flow"]["audio_path"]) + batch_size_s2t = len(batch["s2t_flow"]["audio_path"]) + + logger.info(f"batch_size_v2t: {batch_size_v2t}, batch_size_t2s: {batch_size_t2s}, batch_size_s2t: {batch_size_s2t}" ) + + # print(f"Rank {accelerator.process_index} loading data...") + # print(batch["s2t_flow"]["audio_path"]) + # print(batch["v2t_flow"]['captions']) + + audio_paths_s2t, texts_s2t = batch["s2t_flow"]["audio_path"], batch["s2t_flow"]["text"] + audio_paths_t2s, texts_t2s = batch["t2s_flow"]["audio_path"], batch["t2s_flow"]["text"] + offset = speech_vocab_start + video_tensor, texts_vid = batch["v2t_flow"]["video"], batch["v2t_flow"]["captions"] + + data_time_m.update(time.time() - end) + + # *-------*-------*-------*-------*-------*-------*-------*-------*-------*-------*-------* + # Build formatted sequences for video understanding + # *-------*-------*-------*-------*-------*-------*-------*-------*-------*-------*-------* + video_tensor = video_tensor.to(accelerator.device, non_blocking=True) + video_token_list = [] + for video in video_tensor: # each video is (T, C, H, W) + video_token = vq_model_image.get_code(video) # (T, D) + # each video is tokenized into (T, D) + video_token = video_token + len(uni_prompting.text_tokenizer) # add offset for video tokens + video_token = video_token.view(-1) # flatten to (T*D) + video_token_list.append(video_token) + + video_tokens = torch.stack(video_token_list, dim=0) # (B, T*D) + input_ids_vid, prompt_masks_vid, labels_vid = uni_prompting((video_tokens, texts_vid), 'v2t') + + ( + input_ids_vid, + labels_vid, + p_mask_vid, + answer_lengths_vid + ) = prepare_inputs_and_labels_for_mmu(input_ids_vid, prompt_masks_vid, labels_vid) + + input_ids_vid = input_ids_vid.to(accelerator.device, non_blocking=True) + # *-------*-------*-------*-------*-------*-------*-------*-------*-------*-------*-------* + # Build formatted sequences for speech understanding + # *-------*-------*-------*-------*-------*-------*-------*-------*-------*-------*-------* + prompt_s2t = ['<|start_header_id|>user<|end_header_id|>\n' + prompt + '<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n' for prompt in S2T_INSTRUCTION] + + all_audio_tokens = [] + for path in audio_paths_s2t: + tokens = vq_model_audio.encode(path).to(accelerator.device, non_blocking=True) + tokens_with_offset = tokens + offset + all_audio_tokens.append(tokens_with_offset) + + prompt = random.choice(prompt_s2t) + texts_with_prompt = [f"{prompt}{text}" for text in texts_s2t] + + input_ids_s2t, prompt_masks_s2t, labels_s2t = uni_prompting((all_audio_tokens, texts_with_prompt), 's2t') + input_ids_s2t, labels_s2t, p_mask_s2t, answer_lengths_s2t = prepare_inputs_and_labels_for_s2t(input_ids_s2t, prompt_masks_s2t, labels_s2t) + + # *-------*-------*-------*-------*-------*-------*-------*-------*-------*-------*-------* + # Build formatted sequences for speech generation + # *-------*-------*-------*-------*-------*-------*-------*-------*-------*-------*-------* + prompt_t2s = [prompt for prompt in T2S_INSTRUCTION] + + all_audio_tokens = [] + for path in audio_paths_t2s: + tokens = vq_model_audio.encode(path).to(accelerator.device, non_blocking=True) + tokens_with_offset = tokens + offset + all_audio_tokens.append(tokens_with_offset) + + prompt = random.choice(prompt_t2s) + texts_with_prompt = [f"{text}\n{prompt}" for text in texts_t2s] + + input_ids_t2s, prompt_masks_t2s, labels_t2s = uni_prompting((texts_with_prompt, all_audio_tokens), 't2s') + input_ids_t2s, labels_t2s, p_mask_t2s, answer_lengths_t2s = prepare_inputs_and_labels_for_t2s(input_ids_t2s, prompt_masks_t2s, labels_t2s) + + # --------------------------------------------------------------------------------- + # 1. Define padding values + pad_token_id = uni_prompting.text_tokenizer.eos_token_id + + # 2. Find the maximum sequence length in the current batch + max_len = max( + input_ids_vid.shape[1], + input_ids_s2t.shape[1], + input_ids_t2s.shape[1] + ) + + # 3. Pad all tensors to the max_len + input_ids_vid = pad_tensor(input_ids_vid, max_len, pad_token_id) + input_ids_s2t = pad_tensor(input_ids_s2t, max_len, pad_token_id) + input_ids_t2s = pad_tensor(input_ids_t2s, max_len, pad_token_id) + labels_vid = pad_tensor(labels_vid, max_len, -100) + labels_s2t = pad_tensor(labels_s2t, max_len, -100) + labels_t2s = pad_tensor(labels_t2s, max_len, -100) + p_mask_vid = pad_tensor(p_mask_vid, max_len, 1.0) + p_mask_s2t = pad_tensor(p_mask_s2t, max_len, 1.0) + p_mask_t2s = pad_tensor(p_mask_t2s, max_len, 1.0) + answer_lengths_vid = pad_answer_lengths(answer_lengths_vid, max_len) + answer_lengths_s2t = pad_answer_lengths(answer_lengths_s2t, max_len) + answer_lengths_t2s = pad_answer_lengths(answer_lengths_t2s, max_len) + # --------------------------------------------------------------------------------- + + input_ids = torch.cat(( + input_ids_vid, + input_ids_s2t, + input_ids_t2s + ), dim=0) + labels = torch.cat(( + labels_vid, + labels_s2t, + labels_t2s + ), dim=0) + + # w/o texts and images + p_mask_lm = None + p_mask_mmu = None + answer_lengths_mmu = None + t2i_masks = None + + if global_step == 0 and epoch == 0: + logger.info("Input ids: {}".format(input_ids)) + logger.info("Input ids shape: {}".format(input_ids.shape)) + logger.info("Labels: {}".format(labels)) + + # with accelerator.accumulate(model): + logits, loss_t2i, loss_lm, loss_mmu, loss_vid, loss_s2t, loss_t2s = accelerator.unwrap_model(model).forward_process( + # logits, loss_t2i, loss_lm, loss_mmu, loss_vid, loss_s2t, loss_t2s = model.forward_process( + input_ids=input_ids, + labels=labels, + batch_size_t2i=batch_size_t2i, + batch_size_lm=batch_size_lm, + batch_size_mmu=batch_size_mmu, + batch_size_v2t=batch_size_v2t, + batch_size_s2t=batch_size_s2t, + batch_size_t2s=batch_size_t2s, + max_seq_length=config.dataset.preprocessing.max_seq_length, + p_mask_lm=p_mask_lm, + p_mask_mmu=p_mask_mmu, + p_mask_vid=p_mask_vid, + p_mask_s2t=p_mask_s2t, + p_mask_t2s=p_mask_t2s, + answer_lengths_mmu=answer_lengths_mmu, + answer_lengths_vid=answer_lengths_vid, + answer_lengths_s2t=answer_lengths_s2t, + answer_lengths_t2s=answer_lengths_t2s, + t2i_masks=t2i_masks, + t2s_vocab_start=speech_vocab_start, + t2s_codebook_size=audio_codebook_size, + t2s_special_token_ids=t2s_special_token_ids, + ) + + # Gather the losses across all processes for logging (if we use distributed training). + # avg_loss_t2i = accelerator.gather(loss_t2i.repeat(config.training.batch_size_t2i)).mean() + # avg_loss_lm = accelerator.gather(loss_lm.repeat(config.training.batch_size_lm)).mean() + # avg_loss_mmu = accelerator.gather(loss_mmu.repeat(config.training.batch_size_mmu)).mean() + + avg_loss_vid = accelerator.gather(loss_vid.repeat(config.training.batch_size_v2t)).mean() + avg_loss_s2t = accelerator.gather(loss_s2t.repeat(config.training.batch_size_s2t)).mean() + avg_loss_t2s = accelerator.gather(loss_t2s.repeat(config.training.batch_size_t2s)).mean() + + # loss = (config.training.t2i_coeff * loss_t2i + + # config.training.lm_coeff * loss_lm + + # config.training.mmu_coeff * loss_mmu + + # config.training.vid_coeff * loss_vid + + # config.training.s2t_coeff * loss_s2t + + # config.training.t2s_coeff * loss_t2s) + + loss = (config.training.v2t_coeff * loss_vid + + config.training.s2t_coeff * loss_s2t + + config.training.t2s_coeff * loss_t2s) + + # HMM~~~~~ + avg_masking_rate = accelerator.gather(p_mask_t2s.mean()).mean() + + accelerator.backward(loss) + + if config.training.max_grad_norm is not None and accelerator.sync_gradients: + accelerator.clip_grad_norm_(model.parameters(), config.training.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + + # log gradient norm before zeroing it + if ( + accelerator.sync_gradients + and (global_step + 1) % config.experiment.log_grad_norm_every == 0 + and accelerator.is_main_process + ): + log_grad_norm(model, accelerator, global_step + 1) + + optimizer.zero_grad(set_to_none=True) + + if accelerator.sync_gradients: + batch_time_m.update(time.time() - end) + end = time.time() + + # Log metrics + if (global_step + 1) % config.experiment.log_every == 0: + samples_per_second_per_gpu = ( + config.training.gradient_accumulation_steps * total_batch_size_per_gpu / batch_time_m.val + ) + logs = { + # "step_loss_t2i": avg_loss_t2i.item(), + # "step_loss_mmu": avg_loss_mmu.item(), + # "step_loss_lm": avg_loss_lm.item(), + "step_loss_vid": avg_loss_vid.item(), + "step_loss_s2t": avg_loss_s2t.item(), + "step_loss_t2s": avg_loss_t2s.item(), + "lr": lr_scheduler.get_last_lr()[0], + # "avg_masking_rate": avg_masking_rate.item(), + "samples/sec/gpu": samples_per_second_per_gpu, + "data_time": data_time_m.val, + "batch_time": batch_time_m.val, + } + accelerator.log(logs, step=global_step + 1) + + logger.info( + f"Step: {global_step + 1} " + # f"Loss_t2i: {avg_loss_t2i.item():0.4f} " + # f"Loss_mmu: {avg_loss_mmu.item():0.4f} " + # f"Loss_lm: {avg_loss_lm.item():0.4f} " + f"Loss_vid: {avg_loss_vid.item():0.4f} " + f"Loss_s2t: {avg_loss_s2t.item():0.4f} " + f"Loss_t2s: {avg_loss_t2s.item():0.4f} " + f"Data (t): {data_time_m.val:0.4f}, {samples_per_second_per_gpu:0.2f}/s/gpu " + f"Batch (t): {batch_time_m.val:0.4f} " + f"LR: {lr_scheduler.get_last_lr()[0]:0.6f}" + ) + + # resetting batch / data time meters per log window + batch_time_m.reset() + data_time_m.reset() + + # Save model checkpoint + if (global_step + 1) % config.experiment.save_every == 0: + save_checkpoint(model, config, accelerator, global_step + 1, uni_prompting) + + # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + # ++++++++++++++++++++++ RUN EVALUATION +++++++++++++++++++++++++ + # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + if global_step == 0 or (global_step + 1) % config.experiment.get("eval_every", 5000) == 0: + run_evaluation( + model=accelerator.unwrap_model(model), + vq_model_image=vq_model_image, + vq_model_audio=vq_model_audio, + uni_prompting=uni_prompting, + config=config, + accelerator=accelerator, + global_step=global_step + 1 + ) + # Evaluation function sets model back to train mode internally + # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + + global_step += 1 + + if global_step >= config.training.max_train_steps: + break + + if global_step >= config.training.max_train_steps: + break + + accelerator.wait_for_everyone() + + save_checkpoint(model, config, accelerator, global_step, uni_prompting) + + if accelerator.is_main_process: + model = accelerator.unwrap_model(model) + model.save_pretrained(config.experiment.output_dir, safe_serialization=True) + + accelerator.end_training() + +@torch.no_grad() +def visualize_predictions(*args, **kwargs): + # This function is not called in the main loop but kept for compatibility + pass + +@torch.no_grad() +def generate_images(*args, **kwargs): + # This function is not called in the main loop but kept for compatibility + pass + +@torch.no_grad() +def understanding_images(*args, **kwargs): + # This function is not called in the main loop but kept for compatibility + pass + +def save_checkpoint(model, config, accelerator, global_step, uni_prompting): + output_dir = config.experiment.output_dir + checkpoints_total_limit = config.experiment.get("checkpoints_total_limit", None) + + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if accelerator.is_main_process and checkpoints_total_limit is not None: + checkpoints = os.listdir(output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= checkpoints_total_limit: + num_to_remove = len(checkpoints) - checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = Path(output_dir) / f"checkpoint-{global_step}" + + # retrieve the model on all processes for deepspeed stage 3 to work then save on one process (we are not using stage 3 yet) + # XXX: could also make this conditional on deepspeed + state_dict = accelerator.get_state_dict(model) + if accelerator.is_main_process: + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.save_pretrained( + save_path / "unwrapped_model", + save_function=accelerator.save, + state_dict=state_dict, + safe_serialization=True + ) + json.dump({"global_step": global_step}, (save_path / "metadata.json").open("w+")) + logger.info(f"Saved state to {save_path}") + + # save tokenizer + uni_prompting.text_tokenizer.save_pretrained(save_path/ "unwrapped_model") + + +def log_grad_norm(model, accelerator, global_step): + for name, param in model.named_parameters(): + if param.grad is not None: + grads = param.grad.detach().data + grad_norm = (grads.norm(p=2) / grads.numel()).item() + accelerator.log({"grad_norm/" + name: grad_norm}, step=global_step) + + +if __name__ == "__main__": + main() diff --git a/MMaDA/training/train_omada_stage1-3.py b/MMaDA/training/train_omada_stage1-3.py new file mode 100644 index 0000000000000000000000000000000000000000..77de37a91eebeaf5fb85a4f1ab86d9c9be92506d --- /dev/null +++ b/MMaDA/training/train_omada_stage1-3.py @@ -0,0 +1,1748 @@ +# Copyright 2025 AIDAS Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +os.environ["TOKENIZERS_PARALLELISM"] = "true" +import json +import pandas +import logging +import math +import shutil +import time +import cv2 +import glob +import random +from tqdm import tqdm +from pathlib import Path +from typing import Optional, Union +import csv +import numpy as np +from PIL import Image +from omegaconf import OmegaConf +import wandb +import torch +from torch.optim import AdamW +from lightning.pytorch.utilities import CombinedLoader + +from transformers import AutoTokenizer, AutoConfig +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import DistributedType, set_seed +# +++++ I2I-specific Imports +++++ +from datasets import load_dataset +from torch.utils.data import Dataset, DataLoader +from tqdm.auto import tqdm +# ++++++++++++++++++++++++++++++ + +# +++++ Omni-modal-specific Imports +++++ +from models.modeling_emova_speech_tokenizer import EMOVASpeechTokenizer +from datasets import load_dataset +from torch.utils.data import Dataset, DataLoader, DistributedSampler +from tqdm.auto import tqdm +from training.data import SpeechTextDataset, MixedSpeechTextDataset, load_video_mp4, VideoCaptionDataset, S2T_INSTRUCTION, T2S_INSTRUCTION +# import librosa + +from training.data import Text2ImageDataset +from training.utils import get_config, flatten_omega_conf, image_transform +from training.imagenet_dataset import ImageNetDataset +from parquet import RefinedWebDataset, ChatDataset + +from models import MAGVITv2, get_mask_schedule, OMadaModelLM, OMadaConfig +from training.prompting_utils import UniversalPrompting +from models.lr_schedulers import get_scheduler +from models.logging import set_verbosity_info, set_verbosity_error + +from torch.utils.data import DataLoader, Dataset +from torch.utils.data.distributed import DistributedSampler + +# ++++++++ EVALUATION IMPORTS ++++++++ +import re +import editdistance +import soundfile as sf +from functools import partial +from transformers import pipeline +# ++++++++++++++++++++++++++++++++++++ + +SYSTEM_PROMPT_LEN = 28 + +from training.utils import get_config, flatten_omega_conf, mask_or_random_replace_tokens, AverageMeter + +try: + import apex + + is_apex_available = True +except ImportError: + is_apex_available = False + +logger = get_logger(__name__, log_level="INFO") + +def pad_tensor(tensor, length, value): + pad_size = length - tensor.shape[1] + if pad_size <= 0: + return tensor + # Pad on the right side of the sequence (last dimension) + return torch.nn.functional.pad(tensor, (0, pad_size), "constant", value) + +def pad_answer_lengths(ans: torch.Tensor, length: int) -> torch.Tensor: + b, l = ans.shape + if l >= length: + return ans + pad_block = ans[:, :1].expand(b, length - l) + return torch.cat([ans, pad_block], dim=1) + +def resize_vocab(model, config): + logger.info(f"Resizing token embeddings to {config.model.omada.new_vocab_size}") + model.resize_token_embeddings(config.model.omada.new_vocab_size) + +def get_vq_model_class(model_type): + if model_type == "magvitv2": + return MAGVITv2 + elif model_type == "emova": + return EMOVASpeechTokenizer.from_pretrained( + "Emova-ollm/emova_speech_tokenizer_hf" + ) + else: + raise ValueError(f"model_type {model_type} not supported.") + +def collate_fn_audio(batch): + # In this setup, the tokenizer handles batching of audio paths + return { + 'audio_path': [item['audio_path'] for item in batch], + 'text': [item['text'] for item in batch], + } + +def collate_fn_video_caption(batch): + + batch = [item for item in batch if item is not None] + if len(batch) == 0: + return None + + frame_list = [] + input_ids_list = [] + for item in batch: + frame_tensor = torch.stack(item['video'], dim=0) # (T, C, H, W) + frame_list.append(frame_tensor) + input_ids_list.append(item['caption']) + + frames = torch.stack(frame_list, dim=0) # (B, T, C, H, W) + + return { + "video": frames, # torch tensor (B, T, C, H, W) + "captions": input_ids_list # input_ids (B, seq_len) + } + +def s2t_eval_collate_fn(batch, vq_model_audio, tokenizer, uni_prompting, config): + + audio_tokens_batch = [] + offset = len(uni_prompting.text_tokenizer) + int(config.model.omada.codebook_size) + for item in batch: + path = item['audio_path'] + tokens = vq_model_audio.encode(path) + tokens_with_offset = tokens + offset + audio_tokens_batch.append(tokens_with_offset) + + sptids_dict = uni_prompting.sptids_dict + device = audio_tokens_batch[0].device + batched_input_ids = [] + + for audio_tokens in audio_tokens_batch: + task_tensor = sptids_dict['<|s2t|>'].to(device).unsqueeze(0) + soa_tensor = sptids_dict['<|soa|>'].to(device).unsqueeze(0) + eoa_tensor = sptids_dict['<|eoa|>'].to(device).unsqueeze(0) + audio_block = torch.cat([task_tensor, soa_tensor, audio_tokens, eoa_tensor], dim=1) + + prompt_text = random.choice(S2T_INSTRUCTION) + full_prompt_text = f'<|start_header_id|>user<|end_header_id|>\n{prompt_text}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n' + prompt_tensor = tokenizer(full_prompt_text, return_tensors="pt").input_ids.to(device) + + final_sequence = torch.cat([audio_block, prompt_tensor], dim=1) + batched_input_ids.append(final_sequence.squeeze(0)) + + max_len = max(seq.size(0) for seq in batched_input_ids) + pad_token_id = 126093 + + final_batch_input_ids = torch.full( + (len(batched_input_ids), max_len), + pad_token_id, + dtype=torch.long, + device=device + ) + + for i, seq in enumerate(batched_input_ids): + final_batch_input_ids[i, -len(seq):] = seq + + return { + "input_ids": final_batch_input_ids, + "gt_texts": [item['gt_text'] for item in batch], + "sample_ids": [item['sample_id'] for item in batch] + } + +################################################################################################ +# +++++++++++++++++++++++++++++++++++++ EVALUATION HELPERS +++++++++++++++++++++++++++++++++++++ +################################################################################################ + +def normalize_text(text): + """A simple normalizer for WER calculation.""" + text = text.lower() + text = re.sub(r"[^\w\s']", "", text) + return text + +def calculate_wer(predictions, references): + """Calculates the Word Error Rate (WER) between predicted and ground truth texts.""" + predictions = [normalize_text(p) for p in predictions] + references = [normalize_text(r) for r in references] + + total_errors = 0 + total_words = 0 + for pred, ref in zip(predictions, references): + pred_words = pred.split() + ref_words = ref.split() + total_errors += editdistance.eval(pred_words, ref_words) + total_words += len(ref_words) + + wer = total_errors / total_words if total_words > 0 else 0.0 + return wer, total_errors, total_words + +class S2TEvalDataset(Dataset): + def __init__(self, hf_dataset, root_path): + self.hf_dataset = hf_dataset + self.root_path = root_path + + def __len__(self): + return len(self.hf_dataset) + + def __getitem__(self, idx): + example = self.hf_dataset[idx] + sample_id = example['id'] + speaker_id, chapter_id, _ = sample_id.split('-') + audio_path = os.path.join(self.root_path, speaker_id, chapter_id, f"{sample_id}.flac") + + return { + "audio_path": audio_path, + "gt_text": example["text"], + "sample_id": sample_id + } + +# --- T2S Evaluation Dataset --- +class T2SEvalDataset(Dataset): + def __init__(self, hf_dataset): + self.hf_dataset = hf_dataset + def __len__(self): + return len(self.hf_dataset) + def __getitem__(self, idx): + example = self.hf_dataset[idx] + return {"gt_text": example['text'], "sample_id": example['id']} + + +################################################################################################ +# +++++++++++++++++++++++++++++++++++++ S2T EVALUATION LOGIC +++++++++++++++++++++++++++++++++++++ +################################################################################################ +@torch.no_grad() +def evaluate_s2t(model, vq_model_audio, uni_prompting, config, accelerator, global_step): + if not accelerator.is_main_process: + return + logger.info("***** Running S2T Evaluation (WER on Librispeech test-clean) *****") + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.eval() + + # 1. Load Dataset + try: + s2t_eval_dataset_raw = load_dataset("librispeech_asr", "clean", split="test", streaming=False).select(range(32)) + s2t_eval_dataset = S2TEvalDataset(s2t_eval_dataset_raw, root_path = "/home/work/AIDAS/data/audio/LibriSpeech/test-clean") + except Exception as e: + logger.error(f"Failed to load S2T evaluation dataset: {e}") + return + + collate_with_args = partial( + s2t_eval_collate_fn, + vq_model_audio=vq_model_audio, + tokenizer=uni_prompting.text_tokenizer, + uni_prompting=uni_prompting, + config=config + ) + + s2t_eval_dataloader = DataLoader(s2t_eval_dataset, batch_size=config.training.batch_size_s2t, shuffle=False, collate_fn=collate_with_args) + + local_results = [] + + for batch in tqdm(s2t_eval_dataloader, desc="S2T Evaluation"): + input_ids = batch["input_ids"] + gt_texts = batch["gt_texts"] + sample_ids = batch["sample_ids"] + + output_ids = unwrapped_model.mmu_generate(input_ids, max_new_tokens=256, steps=256, block_length=128, remasking='low_confidence') + + decoded_texts = uni_prompting.text_tokenizer.batch_decode(output_ids[:, input_ids.shape[1]:], skip_special_tokens=True) + + eos_token = uni_prompting.text_tokenizer.eos_token + eos_marker = eos_token if eos_token is not None else "" + for i in range(len(decoded_texts)): + full_text = decoded_texts[i] + eos_idx = full_text.find(eos_marker) + cleaned_text = full_text[:eos_idx] if eos_idx != -1 else full_text + cleaned_text = cleaned_text.replace(eos_marker, "").strip() + local_results.append({ + "sample_id": sample_ids[i], + "gt_text": gt_texts[i], + "decoded_text": cleaned_text, + }) + + if not local_results: + logger.warning("S2T evaluation produced no results.") + return + + gt_list = [res["gt_text"] for res in local_results] + pred_list = [res["decoded_text"] for res in local_results] + + wer, errors, words = calculate_wer(pred_list, gt_list) + logger.info(f"S2T Final WER (Librispeech test-clean): {wer:.4f} | Word Errors: {errors} | Total Words: {words}") + + accelerator.log({ + "eval/s2t_wer": wer, + "eval/s2t_word_errors": errors, + "eval/s2t_total_words": words + }, step=global_step) + + samples_table = wandb.Table(columns=["ID", "Ground Truth", "Prediction"]) + for idx, res in enumerate(local_results): + sample_id = res.get("sample_id", idx) + samples_table.add_data(sample_id, res["gt_text"], res["decoded_text"]) + + accelerator.log({"eval/s2t_samples": samples_table}, step=global_step) + +################################################################################################ +# +++++++++++++++++++++++++++++++++++++ T2S EVALUATION LOGIC +++++++++++++++++++++++++++++++++++++ +################################################################################################ +@torch.no_grad() +def evaluate_t2s(model, vq_model_audio, uni_prompting, config, accelerator, global_step): + if not accelerator.is_main_process: + return + logger.info("***** Running T2S Evaluation (WER via Whisper on Librispeech) *****") + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.eval() + + # 1. Load Dataset & Whisper Model + try: + t2s_eval_dataset_raw = load_dataset("librispeech_asr", "clean", split="test").select(range(32)) + whisper_pipe = pipeline("automatic-speech-recognition", model="openai/whisper-large-v3", device=accelerator.device) + os.makedirs(f"{config.experiment.output_dir}/eval_audio", exist_ok=True) + except Exception as e: + logger.error(f"Failed to load T2S dataset or Whisper model: {e}") + return + + output_dir_per_step = os.path.join("/home/work/AIDAS", config.experiment.output_dir, "eval_audio", f"step_{global_step}") + os.makedirs(output_dir_per_step, exist_ok=True) + + t2s_eval_dataset = T2SEvalDataset(t2s_eval_dataset_raw) + t2s_dataloader = DataLoader(t2s_eval_dataset, batch_size=config.training.batch_size_t2s) + + local_results = [] + mask_token_id = unwrapped_model.config.mask_token_id + mask_schedule = get_mask_schedule(config.training.get("mask_schedule", "cosine")) + + # 2. Evaluation Loop + for batch in tqdm(t2s_dataloader, desc="T2S Evaluation"): + gt_texts = batch["gt_text"] + sample_ids = batch["sample_id"] + + # Chat-style instruction formatting for T2S: user prompt + text + prompts = [ + f"<|start_header_id|>user<|end_header_id|>\n{random.choice(T2S_INSTRUCTION)}\n{text}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n" + for text in gt_texts + ] + batch_size = len(prompts) + + # We need a reasonable length for generated audio tokens + speech_token_length = 384 - 1 # -1 for soa token + audio_tokens = torch.ones((batch_size, speech_token_length), dtype=torch.long, device=accelerator.device) * mask_token_id + input_ids, attention_mask = uni_prompting((prompts, audio_tokens), 't2s_gen') + + if config.training.guidance_scale > 0: + uncond_input_ids, uncond_attention_mask = uni_prompting(([''] * batch_size, audio_tokens), 't2s_gen') + else: + uncond_input_ids, uncond_attention_mask = None, None + + output_ids = unwrapped_model.t2s_generate( + input_ids=input_ids, + uncond_input_ids=uncond_input_ids, + attention_mask=attention_mask, + uncond_attention_mask=uncond_attention_mask, + guidance_scale=5.0, + temperature=1.0, + timesteps=50, + noise_schedule=mask_schedule, + noise_type="mask", + seq_len=383, + uni_prompting=uni_prompting, + config=config, + ) + + # Decode and run Whisper + for i in range(batch_size): + gt = gt_texts[i].rsplit("\n", 1)[-1].strip() + + gen_speech_tokens = output_ids[i] + + # Remove padding/eos if necessary, clamp to valid range + # gen_speech_tokens = torch.clamp(gen_speech_tokens, min=0, max= 4096 - 1) + id_list = gen_speech_tokens.cpu().tolist() + + if not id_list: + logger.warning(f"Generated token list is empty for sample {sample_ids[i]}. Skipping.") + continue + + speech_unit_str = " ".join(map(str, id_list)) + speech_unit_for_decode = "".join([f"<|speech_{unit}|>" for unit in speech_unit_str.split(" ")]) + + filename = f"process_{accelerator.process_index}_{sample_ids[i]}.wav" + output_wav_path = os.path.join(output_dir_per_step, filename) + condition = 'gender-female_emotion-neutral_speed-normal_pitch-normal' + + audio_array = vq_model_audio.decode(speech_unit_for_decode, condition=condition, output_wav_file=output_wav_path) + + whisper_result = whisper_pipe(output_wav_path, generate_kwargs={"language": "english"}) + whisper_text = whisper_result.get("text", "") + + local_results.append({ + "sample_id": sample_ids[i], "gt_text": gt, "whisper_text": whisper_text, "audio_path": output_wav_path + }) + + if not local_results: + logger.warning("Skipping T2S evaluation logging because no samples were generated.") + return + + gt_list = [res["gt_text"] for res in local_results] + pred_list = [res["whisper_text"] for res in local_results] + + wer, errors, words = calculate_wer(pred_list, gt_list) + logger.info(f"T2S Final WER (via Whisper): {wer:.4f} | Word Errors: {errors} | Total Words: {words}") + + accelerator.log({ + "eval/t2s_wer": wer, + "eval/t2s_word_errors": errors, + "eval/t2s_total_words": words + }, step=global_step) + + results_table = wandb.Table(columns=["ID", "Ground Truth", "Whisper Transcription", "Generated Audio"]) + for res in local_results[:8]: + audio = wandb.Audio(res["audio_path"], caption=res["whisper_text"]) + results_table.add_data(res["sample_id"], res["gt_text"], res["whisper_text"], audio) + + accelerator.log({"eval/t2s_samples": results_table}, step=global_step) + +@torch.no_grad() +def evaluate_t2s_mmu_like(model, vq_model_audio, uni_prompting, config, accelerator, global_step): + """Text-to-speech evaluation using the MMU-style block refinement decoder.""" + + if not accelerator.is_main_process: + return + + logger.info("***** Running T2S Evaluation (MMU-style decoder) *****") + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.eval() + + try: + t2s_eval_dataset_raw = load_dataset("librispeech_asr", "clean", split="test").select(range(32)) + whisper_pipe = pipeline("automatic-speech-recognition", model="openai/whisper-large-v3", device=accelerator.device) + os.makedirs(f"{config.experiment.output_dir}/eval_audio", exist_ok=True) + except Exception as exc: + logger.error(f"Failed to load T2S dataset or Whisper model for MMU-style eval: {exc}") + return + + output_dir_per_step = os.path.join("/home/work/AIDAS", config.experiment.output_dir, "eval_audio", f"step_{global_step}_mmu") + os.makedirs(output_dir_per_step, exist_ok=True) + + t2s_eval_dataset = T2SEvalDataset(t2s_eval_dataset_raw) + t2s_dataloader = DataLoader(t2s_eval_dataset, batch_size=config.training.batch_size_t2s) + + local_results = [] + mask_token_id = unwrapped_model.config.mask_token_id + + codebook_size = config.model.omada.codebook_size + speech_vocab_size = 4096 + + for batch in tqdm(t2s_dataloader, desc="T2S MMU Eval"): + gt_texts = batch["gt_text"] + sample_ids = batch["sample_id"] + + prompts = [ + f"<|start_header_id|>user<|end_header_id|>\n{random.choice(T2S_INSTRUCTION)}\n{text}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n" + for text in gt_texts + ] + + batch_size = len(prompts) + speech_token_length = 384 - 1 + audio_tokens = torch.ones((batch_size, speech_token_length), dtype=torch.long, device=accelerator.device) * mask_token_id + input_ids, attention_mask = uni_prompting((prompts, audio_tokens), 't2s_gen') + + output_ids = unwrapped_model.t2s_generate_mmu_like( + input_ids=input_ids, + max_new_tokens=speech_token_length, + steps=384 - 1, + block_length=384 - 1, + temperature=1.0, + cfg_scale=3.0, + mask_token_id=mask_token_id, + attention_mask=attention_mask, + uni_prompting=uni_prompting, + codebook_size=codebook_size, + audio_codebook_size=speech_vocab_size, + ) + + for i in range(batch_size): + gt = gt_texts[i].rsplit("\n", 1)[-1].strip() + + gen_speech_tokens = output_ids[i] + if isinstance(gen_speech_tokens, torch.Tensor): + gen_speech_tokens = gen_speech_tokens.detach().cpu() + + token_list = gen_speech_tokens.tolist() + if not token_list: + logger.warning(f"Generated token list is empty for sample {sample_ids[i]} (MMU eval). Skipping.") + continue + + speech_unit_str = " ".join(map(str, token_list)) + speech_unit_for_decode = "".join([f"<|speech_{unit}|>" for unit in speech_unit_str.split(" ")]) + + filename = f"process_{accelerator.process_index}_{sample_ids[i]}_mmu.wav" + output_wav_path = os.path.join(output_dir_per_step, filename) + condition = 'gender-female_emotion-neutral_speed-normal_pitch-normal' + + try: + vq_model_audio.decode(speech_unit_for_decode, condition=condition, output_wav_file=output_wav_path) + except Exception as exc: + logger.error(f"Decoding failed for sample {sample_ids[i]} (MMU eval): {exc}") + continue + + whisper_result = whisper_pipe(output_wav_path, generate_kwargs={"language": "english"}) + whisper_text = whisper_result.get("text", "") + + local_results.append({ + "sample_id": sample_ids[i], + "gt_text": gt, + "whisper_text": whisper_text, + "audio_path": output_wav_path, + }) + + if not local_results: + logger.warning("Skipping T2S MMU-style evaluation because no samples were generated.") + return + + gt_list = [res["gt_text"] for res in local_results] + pred_list = [res["whisper_text"] for res in local_results] + + wer, errors, words = calculate_wer(pred_list, gt_list) + logger.info(f"T2S (MMU-style) Final WER: {wer:.4f} | Word Errors: {errors} | Total Words: {words}") + + accelerator.log({ + "eval/t2s_mmu_like_wer": wer, + "eval/t2s_mmu_like_word_errors": errors, + "eval/t2s_mmu_like_total_words": words, + }, step=global_step) + + results_table = wandb.Table(columns=["ID", "Ground Truth", "Whisper Transcription", "Generated Audio"]) + for res in local_results[:8]: + audio = wandb.Audio(res["audio_path"], caption=res["whisper_text"]) + results_table.add_data(res["sample_id"], res["gt_text"], res["whisper_text"], audio) + + accelerator.log({"eval/t2s_mmu_like_samples": results_table}, step=global_step) + +@torch.no_grad() +def evaluate_t2s_fixed(model, vq_model_audio, uni_prompting, config, accelerator, global_step): + """ + Text-to-Speech (fixed-length) evaluation: + - Input prompt contains SOA + [MASK]*L + EOA (EOA is injected, not predicted) + - The model only fills VQ codes for exactly L positions (no EOA/EOS prediction) + - Generated audio is transcribed by Whisper; we report WER + """ + if not accelerator.is_main_process: + return + logger.info("***** Running T2S (fixed-length) Evaluation *****") + unwrapped = accelerator.unwrap_model(model) + unwrapped.eval() + + # Load eval dataset and Whisper model + try: + ds_raw = load_dataset("librispeech_asr", "clean", split="test").select(range(128)) + whisper_pipe = pipeline( + "automatic-speech-recognition", + model="openai/whisper-large-v3", + device=accelerator.device + ) + os.makedirs(f"{config.experiment.output_dir}/eval_audio", exist_ok=True) + except Exception as e: + logger.error(f"Failed to load dataset or Whisper model: {e}") + return + + # Directory for saving generated audio files of this evaluation step + out_dir = os.path.join( + "/home/work/AIDAS", config.experiment.output_dir, "eval_audio", f"step_{global_step}_fixed" + ) + os.makedirs(out_dir, exist_ok=True) + + eval_ds = T2SEvalDataset(ds_raw) + loader = DataLoader(eval_ds, batch_size=config.training.batch_size_t2s) + + local_results = [] + mask_token_id = unwrapped.config.mask_token_id + mask_schedule = get_mask_schedule(config.training.get("mask_schedule", "cosine")) + + for batch in tqdm(loader, desc="T2S Fixed Evaluation"): + gt_texts = batch["gt_text"] + sample_ids = batch["sample_id"] + + # Chat-style instruction formatting for fixed-length T2S + prompts = [ + f"<|start_header_id|>user<|end_header_id|>\n{random.choice(T2S_INSTRUCTION)}\n{text}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n" + for text in gt_texts + ] + batch_size = len(prompts) + + # We need a reasonable length for generated audio tokens + speech_token_length = 256 - 2 # exclude and + audio_tokens = torch.ones((batch_size, speech_token_length), dtype=torch.long, device=accelerator.device) * mask_token_id + input_ids, attention_mask = uni_prompting((prompts, audio_tokens), 't2s_fixed_gen') + + if config.training.guidance_scale > 0: + uncond_input_ids, uncond_attention_mask = uni_prompting(([''] * batch_size, audio_tokens), 't2s_fixed_gen') + else: + uncond_input_ids, uncond_attention_mask = None, None + + # Core generation call: + # - predict_eoa=False prevents EOA/EOS prediction; only VQ codes are generated + outputs = unwrapped.t2s_fixed_generate( + input_ids=input_ids, + uncond_input_ids=uncond_input_ids, + attention_mask=attention_mask, + uncond_attention_mask=uncond_attention_mask, + guidance_scale=1.5, + temperature=1.0, + timesteps=24, + noise_schedule=mask_schedule, + noise_type="mask", + seq_len=speech_token_length, + uni_prompting=uni_prompting, + config=config, + ) + + # Decode generated VQ codes → waveform via the speech tokenizer, then ASR with Whisper + for i in range(batch_size): + gt = gt_texts[i].rsplit("\n", 1)[-1].strip() + gen_rel = outputs[i] # relative VQ ids in [0..4095] + id_list = gen_rel.tolist() + + if not id_list: + logger.warning(f"[fixed] Empty tokens for {sample_ids[i]}; skipping.") + continue + + # Convert to the speech-unit string format expected by the decoder + unit_str = " ".join(map(str, id_list)) + speech_unit_for_decode = "".join([f"<|speech_{u}|>" for u in unit_str.split(" ")]) + + # Synthesize audio and run Whisper + fname = f"process_{accelerator.process_index}_{sample_ids[i]}_fixed.wav" + wav_path = os.path.join(out_dir, fname) + condition = 'gender-female_emotion-neutral_speed-normal_pitch-normal' + + _ = vq_model_audio.decode( + speech_unit_for_decode, + condition=condition, + output_wav_file=wav_path + ) + asr = whisper_pipe(wav_path, generate_kwargs={"language": "english"}) + whisper_text = asr.get("text", "") + + local_results.append({ + "sample_id": sample_ids[i], + "gt_text": gt, + "whisper_text": whisper_text, + "audio_path": wav_path + }) + + if not local_results: + logger.warning("Skipping T2S fixed evaluation logging because no samples were generated.") + return + + gt_list = [r["gt_text"] for r in local_results] + pred_list = [r["whisper_text"] for r in local_results] + wer, errors, words = calculate_wer(pred_list, gt_list) + logger.info(f"T2S Fixed WER: {wer:.4f} | Errors: {errors} | Words: {words}") + + accelerator.log({ + "eval/t2s_fixed_wer": wer, + "eval/t2s_fixed_errors": errors, + "eval/t2s_fixed_words": words + }, step=global_step) + + table = wandb.Table(columns=["ID", "GT", "ASR", "Audio"]) + for r in local_results[:8]: + table.add_data( + r["sample_id"], + r["gt_text"], + r["whisper_text"], + wandb.Audio(r["audio_path"], caption=r["whisper_text"]) + ) + accelerator.log({"eval/t2s_fixed_samples": table}, step=global_step) + +################################################################################################ +# +++++++++++++++++++++++++++++++++++++ V2T EVALUATION LOGIC +++++++++++++++++++++++++++++++++++++ +################################################################################################ +@torch.no_grad() +def evaluate_v2t(model, vq_model_image, uni_prompting, config, accelerator, global_step): + # This is a qualitative evaluation, so it only runs on the main process. + if not accelerator.is_main_process: + return + + logger.info("***** Running V2T Qualitative Evaluation *****") + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.eval() + + video_root = "/home/work/AIDAS/video/demo" + if not video_root or not os.path.exists(video_root): + logger.warning(f"V2T eval root '{video_root}' not found. Skipping V2T evaluation.") + return + + file_list = [f for f in os.listdir(video_root) if f.lower().endswith('.mp4')] + if not file_list: + logger.warning(f"No .mp4 files found in '{video_root}'. Skipping V2T evaluation.") + return + + question = "Please provide a detailed description of the video." + results_table = wandb.Table(columns=["Video ID", "Question", "Generated Caption"]) + + for file_name in tqdm(file_list[:], desc="V2T Evaluation", disable=not accelerator.is_main_process): + video_path = os.path.join(video_root, file_name) + + # 1. Load and process video + cap = cv2.VideoCapture(video_path) + total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + indices = np.linspace(0, total_frames - 1, 8, dtype=int) + frames = [] + for i in range(total_frames): + ret, frame = cap.read() + if i in indices: + if not ret: continue + frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + pil_img = Image.fromarray(frame) + frames.append(image_transform(pil_img, resolution=config.dataset.preprocessing.resolution)) + cap.release() + + if len(frames) < 8: continue + + video_tensor = torch.stack(frames).to(accelerator.device) + video_tokens = vq_model_image.get_code(video_tensor) + len(uni_prompting.text_tokenizer) + video_tokens = video_tokens.view(1, -1) # Flatten tokens + + sptids = uni_prompting.sptids_dict + device = unwrapped_model.device + + prompt_text = f'<|start_header_id|>user<|end_header_id|>\n{question}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n' + prompt_tensor = uni_prompting.text_tokenizer(prompt_text, return_tensors="pt").input_ids.to(device) + + input_ids = torch.cat([ + sptids['<|v2t|>'].to(device).unsqueeze(0), + sptids['<|soi|>'].to(device).unsqueeze(0), + video_tokens, + sptids['<|eoi|>'].to(device).unsqueeze(0), + sptids['<|sot|>'].to(device).unsqueeze(0), + prompt_tensor + ], dim=1).long() + + output_ids = unwrapped_model.mmu_generate(input_ids, max_new_tokens=256, steps=256, block_length=128) + text = uni_prompting.text_tokenizer.batch_decode(output_ids[:, input_ids.shape[1]:], skip_special_tokens=True)[0] + print(text) + # 3. Log result + results_table.add_data(file_name, question, text) + + # except Exception as e: + # logger.error(f"Error processing video {file_name}: {e}") + + accelerator.log({"eval/v2t_qualitative_samples": results_table}, step=global_step) + + +################################################################################################ +# +++++++++++++++++++++++++++++++++++++ MAIN EVALUATION ORCHESTRATOR +++++++++++++++++++++++++++++ +################################################################################################ + +def run_evaluation(model, vq_model_image, vq_model_audio, uni_prompting, config, accelerator, global_step): + """ + Orchestrates the S2T, T2S, and V2T evaluations. + """ + if accelerator.is_main_process: + logger.info(f"--- Starting evaluation at step {global_step} ---") + model.eval() + + if accelerator.is_main_process: + evaluate_s2t(model, vq_model_audio, uni_prompting, config, accelerator, global_step) + # evaluate_t2s(model, vq_model_audio, uni_prompting, config, accelerator, global_step) + evaluate_t2s_mmu_like(model, vq_model_audio, uni_prompting, config, accelerator, global_step) + # evaluate_t2s_fixed(model, vq_model_audio, uni_prompting, config, accelerator, global_step) + evaluate_v2t(model, vq_model_image, uni_prompting, config, accelerator, global_step) + + accelerator.wait_for_everyone() + if accelerator.is_main_process: + logger.info(f"--- Finished evaluation at step {global_step}. Returning to training. ---") + model.train() + + +def main(): + ######################### + # SETUP Accelerator # + ######################### + config = get_config() + + # Enable TF32 on Ampere GPUs + if config.training.enable_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.benchmark = True + torch.backends.cudnn.deterministic = False + + config.experiment.logging_dir = str(Path(config.experiment.output_dir) / "logs") + accelerator = Accelerator( + gradient_accumulation_steps=config.training.gradient_accumulation_steps, + mixed_precision=config.training.mixed_precision, + log_with="wandb", + project_dir=config.experiment.logging_dir, + split_batches=True, + ) + + total_batch_size_per_gpu = (config.training.batch_size_t2i + + config.training.batch_size_lm + + config.training.batch_size_mmu + + config.training.batch_size_v2t + + config.training.batch_size_s2t + + config.training.batch_size_t2s) + total_batch_size = ( + (config.training.batch_size_t2i + + config.training.batch_size_lm + + config.training.batch_size_mmu + + config.training.batch_size_v2t + + config.training.batch_size_s2t + + config.training.batch_size_t2s) * accelerator.num_processes * config.training.gradient_accumulation_steps + ) + + if accelerator.distributed_type == DistributedType.DEEPSPEED: + accelerator.state.deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = ( + total_batch_size_per_gpu + ) + + ##################################### + # SETUP LOGGING, SEED and CONFIG # + ##################################### + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + set_verbosity_info() + else: + set_verbosity_error() + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + resume_wandb_run = config.wandb.resume + run_id = config.wandb.get("run_id", None) + if run_id is None: + resume_wandb_run = False + run_id = wandb.util.generate_id() + config.wandb.run_id = run_id + + wandb_init_kwargs = dict( + name=config.experiment.name, + id=run_id, + resume=resume_wandb_run, + entity=config.wandb.get("entity", None), + config_exclude_keys=[], + dir = config.experiment.logging_dir, + ) + wandb_config = {k: v for k, v in flatten_omega_conf(config, resolve=True)} + wandb_config.pop("experiment.resume_from_checkpoint") + + accelerator.init_trackers( + config.experiment.project, + config=wandb_config, + init_kwargs={"wandb": wandb_init_kwargs}, + ) + + if accelerator.is_main_process: + os.makedirs(config.experiment.output_dir, exist_ok=True) + config_path = Path(config.experiment.output_dir) / "config.yaml" + logging.info(f"Saving config to {config_path}") + OmegaConf.save(config, config_path) + + # If passed along, set the training seed now. + if config.training.seed is not None: + set_seed(config.training.seed) + + ######################### + # MODELS and OPTIMIZER # + ######################### + logger.info("Loading models and optimizer") + + tokenizer = AutoTokenizer.from_pretrained(config.model.omada.tokenizer_path, padding_side="left") + + uni_prompting = UniversalPrompting(tokenizer, max_text_len=config.dataset.preprocessing.max_seq_length, max_audio_len=config.dataset.preprocessing.max_aud_length, + special_tokens=( + "<|soi|>", "<|eoi|>", "<|sov|>", "<|eov|>", "<|t2i|>", + "<|mmu|>", "<|t2v|>", "<|v2v|>", "<|lvg|>", + # Omada Special Tokens + "<|v2t|>", "<|s2t|>", "<|t2s|>", "<|soa|>", "<|eoa|>", + ), + ignore_id=-100, cond_dropout_prob=config.training.cond_dropout_prob, use_reserved_token=True) + + print('special tokens : \n', uni_prompting.sptids_dict) + + speech_vocab_start = len(uni_prompting.text_tokenizer) + int(config.model.omada.codebook_size) + audio_codebook_size = max(int(config.model.omada.new_vocab_size) - speech_vocab_start, 0) + t2s_special_token_ids = { + "eoa": int(uni_prompting.sptids_dict['<|eoa|>'][0].item()), + "eos": int(uni_prompting.text_tokenizer.eos_token_id), + } + + # VQ model for processing image into discrete tokens + vq_model_image = get_vq_model_class(config.model.vq_model_image.type) + if config.model.vq_model_image.get("pretrained_model_path", None): + vq_model_image = vq_model_image().to(accelerator.device) + state_dict = torch.load(config.model.vq_model_image.pretrained_model_path)['model'] + vq_model_image.load_state_dict(state_dict) + else: + vq_model_image = vq_model_image.from_pretrained(config.model.vq_model_image.vq_model_name).to(accelerator.device) + + vq_model_audio = get_vq_model_class(config.model.vq_model_audio.type) + vq_model_audio = vq_model_audio.from_pretrained(config.model.vq_model_audio.vq_model_name).to(accelerator.device) + + vq_model_image.eval() + vq_model_image.requires_grad_(False) + + vq_model_audio.eval() + vq_model_audio.requires_grad_(False) + + model = OMadaModelLM.from_pretrained(config.model.omada.pretrained_model_path, torch_dtype=torch.bfloat16).to(accelerator.device) + + # Resize Vocab size for Audio Modality + unwrapped_model = accelerator.unwrap_model(model) + original_vocab_size = unwrapped_model.get_input_embeddings().weight.shape[0] + logger.info("="*50) + logger.info(f"Calling resize_vocab...") + logger.info(f"Vocab size BEFORE resizing: {original_vocab_size}") + + resize_vocab(unwrapped_model, config) + + resized_vocab_size = unwrapped_model.get_input_embeddings().weight.shape[0] + logger.info(f"Vocab size AFTER resizing: {resized_vocab_size}") + logger.info(f"Config 'new_vocab_size': {config.model.omada.new_vocab_size}") + + if resized_vocab_size == config.model.omada.new_vocab_size: + logger.info("āœ… Vocab resize successful!") + else: + logger.info("āŒ Vocab resize FAILED or did not match config!") + logger.info("="*50) + mask_id = model.config.mask_token_id + + ################################## + # Optimizer and LR scheduler # + ################################# + optimizer_config = config.optimizer.params + + # no decay on bias and layernorm and embedding + no_decay = ["bias", "layer_norm.weight", "mlm_ln.weight", "embeddings.weight"] + optimizer_grouped_parameters = [ + { + "params": [p for n, p in model.named_parameters() if + p.requires_grad and not any(nd in n for nd in no_decay)], + "weight_decay": optimizer_config.weight_decay, + }, + { + "params": [p for n, p in model.named_parameters() if + p.requires_grad and any(nd in n for nd in no_decay)], + "weight_decay": 0.0, + }, + ] + + optimizer_type = config.optimizer.name + if optimizer_type == "adamw": + optimizer = AdamW( + optimizer_grouped_parameters, + lr=optimizer_config.learning_rate, + betas=(optimizer_config.beta1, optimizer_config.beta2), + weight_decay=optimizer_config.weight_decay, + eps=optimizer_config.epsilon, + ) + else: + raise ValueError(f"Optimizer {optimizer_type} not supported") + + # Create mask scheduler + if config.get("mask_schedule", None) is not None: + schedule = config.mask_schedule.schedule + args = config.mask_schedule.get("params", {}) + mask_schedule = get_mask_schedule(schedule, **args) + else: + mask_schedule = get_mask_schedule(config.training.get("mask_schedule", "cosine")) + + ################################## + # DATALOADER # + ################################# + logger.info("Creating dataloaders and lr_scheduler") + + total_batch_size = ( + (config.training.batch_size_t2s + config.training.batch_size_s2t +config.training.batch_size_v2t) * accelerator.num_processes * config.training.gradient_accumulation_steps + ) + preproc_config = config.dataset.preprocessing + dataset_config = config.dataset.params + + # Video Dataset + video_captioning_dataset = VideoCaptionDataset( + transform=image_transform, + tokenizer=uni_prompting.text_tokenizer, + max_seq_length=preproc_config.max_seq_length, + resolution=preproc_config.resolution, + sample_method="uniform", + num_frames=8, + ) + + sampler_v2t = DistributedSampler( + video_captioning_dataset, + shuffle=True, # Should be true for training + drop_last=True + ) + + train_dataloader_v2t = DataLoader( + video_captioning_dataset, + batch_size=config.training.batch_size_v2t, + num_workers=dataset_config.num_workers, + collate_fn=collate_fn_video_caption, + sampler = sampler_v2t, + drop_last=True, + ) + + # Speech Dataset + dataset_sm = MixedSpeechTextDataset(config.dataset.params.audio_data) + + logger.info(f"Dataset Prepared.") + + # Use distinct DistributedSamplers for each speech dataloader to avoid iterator interference + if accelerator.num_processes > 1: + sampler_s2t = DistributedSampler( + dataset_sm, + num_replicas=accelerator.num_processes, + rank=accelerator.process_index, + shuffle=True, + drop_last=True, + ) + sampler_t2s = DistributedSampler( + dataset_sm, + num_replicas=accelerator.num_processes, + rank=accelerator.process_index, + shuffle=True, + drop_last=True, + ) + else: + sampler_s2t = None + sampler_t2s = None + + train_dataloader_s2t = DataLoader( + dataset_sm, + batch_size=config.training.batch_size_s2t, + shuffle=False, + sampler=sampler_s2t, + collate_fn=collate_fn_audio, + num_workers=config.dataset.params.num_workers, + drop_last=True, + ) + train_dataloader_t2s = DataLoader( + dataset_sm, + batch_size=config.training.batch_size_t2s, + shuffle=False, + sampler=sampler_t2s, + collate_fn=collate_fn_audio, + num_workers=config.dataset.params.num_workers, + drop_last=True, + ) + + # Combine these dataloaders into a single iterable model + iterables = { + "v2t_flow": train_dataloader_v2t, + "t2s_flow": train_dataloader_t2s, + "s2t_flow": train_dataloader_s2t, + } + + combined_dataloader = CombinedLoader(iterables, mode=config.dataset.combined_loader_mode) + + # s2t + total_batch_size_s2t = config.training.batch_size_s2t * accelerator.num_processes * config.training.gradient_accumulation_steps + num_update_steps_per_epoch_s2t = math.ceil(len(dataset_sm) / total_batch_size_s2t) + + # t2s + total_batch_size_t2s = config.training.batch_size_t2s * accelerator.num_processes * config.training.gradient_accumulation_steps + num_update_steps_per_epoch_t2s = math.ceil(len(dataset_sm) / total_batch_size_t2s) + + # v2t + total_batch_size_v2t = (config.training.batch_size_v2t * accelerator.num_processes * config.training.gradient_accumulation_steps) + num_update_steps_per_epoch_v2t = math.ceil(len(video_captioning_dataset) / total_batch_size_v2t) + + + # Calculate num_train_epochs + num_update_steps_per_epoch = max(num_update_steps_per_epoch_s2t, num_update_steps_per_epoch_t2s, num_update_steps_per_epoch_v2t) + num_train_epochs = math.ceil(config.training.max_train_steps / num_update_steps_per_epoch) if num_update_steps_per_epoch > 0 else 1 + + logger.info(f"len of speech: {len(dataset_sm)}") + logger.info(f"len of video: {len(video_captioning_dataset)}") + logger.info(f"Train stpes: {config.training.max_train_steps}") + logger.info(f"Num train epochs: {num_train_epochs}") + + ################################## + # MODEL RESUME # + ################################# + global_step = 0 + first_epoch = 0 + start_step = 0 + + if config.experiment.resume_from_checkpoint: + dirs = os.listdir(config.experiment.output_dir) + logger.info(f"dirs: {dirs}") + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + logger.info(f"path: {path}") + if path is not None: + path = os.path.join(config.experiment.output_dir, path) + logger.info(f"Resuming from checkpoint: {path}") + global_step = start_step = int(os.path.basename(path).split("-")[1]) + first_epoch = global_step // num_update_steps_per_epoch + if os.path.exists(f'{path}/unwrapped_model/pytorch_model.bin'): + state_dict = torch.load(f'{path}/unwrapped_model/pytorch_model.bin', map_location="cpu") + model.load_state_dict(state_dict, strict=True) + del state_dict + elif os.path.exists(f'{path}/unwrapped_model/pytorch_model.bin.index.json'): + from safetensors.torch import load_file + from transformers.modeling_utils import load_sharded_checkpoint + load_sharded_checkpoint(model, f'{path}/unwrapped_model/') + # if safetensors sharded checkpoint exists + elif os.path.exists(f'{path}/unwrapped_model/model.safetensors.index.json'): + from transformers.modeling_utils import load_sharded_checkpoint + load_sharded_checkpoint( + model, + f'{path}/unwrapped_model/', + ) + else: + raise FileNotFoundError(f"Checkpoint {path}/unwrapped_model/pytorch_model.bin or safetensors not found") + else: + logger.info("Not resuming from checkpoint") + + ################################## + # Prepare accelerator # + ################################# + logger.info("Preparing model, optimizer and dataloaders") + + lr_scheduler = get_scheduler( + config.lr_scheduler.scheduler, + optimizer=optimizer, + num_training_steps=config.training.max_train_steps, + num_warmup_steps=config.lr_scheduler.params.warmup_steps, + min_lr_scale=config.lr_scheduler.params.min_lr_scale + ) + + # model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler) + model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler) + + lr_scheduler = get_scheduler( + config.lr_scheduler.scheduler, + optimizer=optimizer, + num_training_steps=config.training.max_train_steps, + num_warmup_steps=config.lr_scheduler.params.warmup_steps, + min_lr_scale=config.lr_scheduler.params.min_lr_scale + ) + + vq_model_image.to(device=accelerator.device) + vq_model_audio.to(device=accelerator.device) + + mask_dtype = model.get_input_embeddings().weight.dtype + + def _log_and_flag_failure(message: str, exc: Exception = None): + """Log preprocessing failures on both logger and accelerator console.""" + if exc is not None: + logger.exception(message) + else: + logger.error(message) + accelerator.print(message) + + def safe_audio_encode(audio_path: str, flow_name: str): + try: + tokens = vq_model_audio.encode(audio_path) + return tokens, None + except Exception as exc: + msg = ( + f"[Rank {accelerator.process_index}] {flow_name} audio encode failed " + f"for '{audio_path}': {exc}" + ) + _log_and_flag_failure(msg, exc) + return None, msg + + def safe_video_get_code(video_tensor_sample: torch.Tensor, sample_index: int): + try: + video_token = vq_model_image.get_code(video_tensor_sample) + return video_token, None + except Exception as exc: + msg = ( + f"[Rank {accelerator.process_index}] v2t video encode failed " + f"for sample index {sample_index}: {exc}" + ) + _log_and_flag_failure(msg, exc) + return None, msg + + ################################## + # Training # + ################################# + logger.info("***** Running training *****") + logger.info(f" Num training steps = {config.training.max_train_steps}") + logger.info(f" Instantaneous batch size per device = {total_batch_size_per_gpu}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {config.training.gradient_accumulation_steps}") + + + @torch.no_grad() + def prepare_inputs_and_labels( + pixel_values_or_image_ids: Union[torch.FloatTensor, torch.LongTensor], + texts: Union[str, str], + min_masking_rate: float = 0.0, + is_train: bool = True, + seed: int = None + ): + + image_tokens = vq_model_image.get_code(pixel_values_or_image_ids) + image_tokens = image_tokens + len(uni_prompting.text_tokenizer) + # create MLM mask and labels + input_ids, labels, loss_weight, mask_prob = mask_or_random_replace_tokens( + image_tokens, + mask_id, + config, + mask_schedule=mask_schedule, + is_train=is_train, + ) + input_ids, masks, labels = uni_prompting((texts, input_ids, labels), 't2i') + return input_ids, labels, mask_prob, image_tokens, masks + + @torch.no_grad() + def prepare_inputs_and_labels_for_text( + texts: Union[str, str], max_seq_len, eps=1e-3 + ): + # create MLM mask and labels + + input_ids_lm, prompt_mask, labels_lm = uni_prompting((texts_lm, max_seq_len), 'lm') + b, l = input_ids_lm.shape + t = torch.rand(b, device=input_ids_lm.device) + p_mask = (1 - eps) * t + eps + p_mask = p_mask[:, None].repeat(1, l) + + masked_indices = torch.rand((b, l), device=input_ids_lm.device) < p_mask + # 126336 is used for [MASK] token + noisy_batch = torch.where(masked_indices, mask_id, input_ids_lm) + masked_indices = noisy_batch == mask_id + + return noisy_batch, labels_lm, p_mask + + # Video also uses this. + @torch.no_grad() + def prepare_inputs_and_labels_for_mmu( + input_ids_mmu, prompt_masks, labels_mmu, eps=1e-3 + ): + b, l = input_ids_mmu.shape + t = torch.rand(b, device=input_ids_mmu.device) + p_mask = (1 - eps) * t + eps + p_mask = p_mask[:, None].repeat(1, l) + + masked_indices = torch.rand((b, l), device=input_ids_mmu.device) < p_mask + # 126336 is used for [MASK] token + noisy_batch = torch.where(masked_indices, mask_id, input_ids_mmu) + masked_indices = noisy_batch == mask_id + noisy_batch[prompt_masks.bool()] = input_ids_mmu[prompt_masks.bool()] + masked_indices = noisy_batch == mask_id + + prompt_masks = prompt_masks.to(torch.int64) + answer_lengths = torch.sum((1 - prompt_masks), dim=-1, keepdim=True) + answer_lengths = answer_lengths.repeat(1, noisy_batch.shape[1]) + + return noisy_batch, labels_mmu, p_mask, answer_lengths + + @torch.no_grad() + def prepare_inputs_and_labels_for_t2s( + input_ids_t2s, prompt_masks, labels_t2s, eps=1e-3 + ): + b, l = input_ids_t2s.shape + t = torch.rand(b, device=input_ids_t2s.device) + p_mask = (1 - eps) * t + eps + p_mask = p_mask[:, None].repeat(1, l) + + masked_indices = torch.rand((b, l), device=input_ids_t2s.device) < p_mask + noisy_batch = torch.where(masked_indices, mask_id, input_ids_t2s) + masked_indices = noisy_batch == mask_id + + noisy_batch[prompt_masks.bool()] = input_ids_t2s[prompt_masks.bool()] + masked_indices = noisy_batch == mask_id + + prompt_masks = prompt_masks.to(torch.int64) + answer_lengths = torch.sum((1 - prompt_masks), dim=-1, keepdim=True) + answer_lengths = answer_lengths.repeat(1, noisy_batch.shape[1]) + + return noisy_batch, labels_t2s, p_mask, answer_lengths + + + @torch.no_grad() + def prepare_inputs_and_labels_for_s2t( + input_ids_mmu, prompt_masks, labels_mmu, eps=1e-3 + ): + b, l = input_ids_mmu.shape + t = torch.rand(b, device=input_ids_mmu.device) + p_mask = (1 - eps) * t + eps + p_mask = p_mask[:, None].repeat(1, l) + + masked_indices = torch.rand((b, l), device=input_ids_mmu.device) < p_mask + # 126336 is used for [MASK] token + noisy_batch = torch.where(masked_indices, mask_id, input_ids_mmu) + masked_indices = noisy_batch == mask_id + noisy_batch[prompt_masks.bool()] = input_ids_mmu[prompt_masks.bool()] + masked_indices = noisy_batch == mask_id + + prompt_masks = prompt_masks.to(torch.int64) + answer_lengths = torch.sum((1 - prompt_masks), dim=-1, keepdim=True) + answer_lengths = answer_lengths.repeat(1, noisy_batch.shape[1]) + + return noisy_batch, labels_mmu, p_mask, answer_lengths + + batch_time_m = AverageMeter() + data_time_m = AverageMeter() + end = time.time() + + for epoch in tqdm(range(first_epoch, num_train_epochs), desc="Epochs", disable=not accelerator.is_main_process, position=0): + # Ensure all samplers reshuffle in a rank-consistent way each epoch + try: + if isinstance(sampler_v2t, DistributedSampler): + sampler_v2t.set_epoch(epoch) + if accelerator.num_processes > 1: + if sampler_s2t is not None: + sampler_s2t.set_epoch(epoch) + if sampler_t2s is not None: + sampler_t2s.set_epoch(epoch) + except Exception: + pass + model.train() + for batch, batch_idx, dataloader_idx in combined_dataloader: + batch_size_t2i = 0 + batch_size_lm = 0 + batch_size_mmu = 0 + + # Synchronize skip decision across all ranks to avoid collective mismatches + local_skip = 1 if (batch is None or batch.get("v2t_flow") is None) else 0 + try: + skip_tensor = torch.tensor(local_skip, device=accelerator.device, dtype=torch.int32) + skip_sum = accelerator.reduce(skip_tensor, reduction='sum') + should_skip = skip_sum.item() > 0 + except Exception: + # Fallback if reduce isn't available for any reason + should_skip = local_skip == 1 + + if should_skip: + if accelerator.is_main_process and local_skip: + logger.warning(f"Skipping step {global_step} (batch is None or v2t_flow missing) [synced]") + continue + + batch_size_v2t = len(batch["v2t_flow"]["video"]) + batch_size_t2s = len(batch["t2s_flow"]["audio_path"]) + batch_size_s2t = len(batch["s2t_flow"]["audio_path"]) + + logger.info(f"batch_size_v2t: {batch_size_v2t}, batch_size_t2s: {batch_size_t2s}, batch_size_s2t: {batch_size_s2t}" ) + + # print(f"Rank {accelerator.process_index} loading data...") + # print(batch["s2t_flow"]["audio_path"]) + # print(batch["v2t_flow"]['captions']) + + audio_paths_s2t, texts_s2t = batch["s2t_flow"]["audio_path"], batch["s2t_flow"]["text"] + audio_paths_t2s, texts_t2s = batch["t2s_flow"]["audio_path"], batch["t2s_flow"]["text"] + offset = speech_vocab_start + video_tensor, texts_vid = batch["v2t_flow"]["video"], batch["v2t_flow"]["captions"] + + data_time_m.update(time.time() - end) + + failure_messages = [] + step_failed = False + + # *-------*-------*-------*-------*-------*-------*-------*-------*-------*-------*-------* + # Build formatted sequences for video understanding + # *-------*-------*-------*-------*-------*-------*-------*-------*-------*-------*-------* + video_tensor = video_tensor.to(accelerator.device, non_blocking=True) + video_token_list = [] + prompt_v2t = ['<|start_header_id|>user<|end_header_id|>\n' + prompt + '<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n' for prompt in V2T_INSTRUCTION] + + for vid_idx, video in enumerate(video_tensor): # each video is (T, C, H, W) + tokens, err = safe_video_get_code(video, vid_idx) + if err is not None: + failure_messages.append(err) + step_failed = True + break + video_token = tokens + len(uni_prompting.text_tokenizer) # add offset for video tokens + video_token = video_token.view(-1) # flatten to (T*D) + video_token_list.append(video_token) + + if not step_failed: + + prompt_v2t_cur = random.choice(prompt_v2t) + + texts_with_prompt = [f"{prompt}{text}" for text in texts_v2t] + video_tokens = torch.stack(video_token_list, dim=0) # (B, T*D) + input_ids_vid, prompt_masks_vid, labels_vid = uni_prompting((video_tokens, texts_vid), 'v2t') + # Keep trailing EOS tokens so v2t learns to emit explicit padding. + + ( + input_ids_vid, + labels_vid, + p_mask_vid, + answer_lengths_vid + ) = prepare_inputs_and_labels_for_mmu(input_ids_vid, prompt_masks_vid, labels_vid) + + input_ids_vid = input_ids_vid.to(accelerator.device, non_blocking=True) + + # *-------*-------*-------*-------*-------*-------*-------*-------*-------*-------*-------* + # Build formatted sequences for speech understanding + # *-------*-------*-------*-------*-------*-------*-------*-------*-------*-------*-------* + if not step_failed: + prompt_s2t = ['<|start_header_id|>user<|end_header_id|>\n' + prompt + '<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n' for prompt in S2T_INSTRUCTION] + + all_audio_tokens = [] + for path in audio_paths_s2t: + tokens, err = safe_audio_encode(path, "s2t") + if err is not None: + failure_messages.append(err) + step_failed = True + break + tokens = tokens.to(accelerator.device, non_blocking=True) + tokens_with_offset = tokens + offset + all_audio_tokens.append(tokens_with_offset) + + if not step_failed: + prompt = random.choice(prompt_s2t) + texts_with_prompt = [f"{prompt}{text}" for text in texts_s2t] + + input_ids_s2t, prompt_masks_s2t, labels_s2t = uni_prompting((all_audio_tokens, texts_with_prompt), 's2t') + # Preserve trailing EOS tokens in s2t targets for explicit prediction. + input_ids_s2t, labels_s2t, p_mask_s2t, answer_lengths_s2t = prepare_inputs_and_labels_for_s2t(input_ids_s2t, prompt_masks_s2t, labels_s2t) + + # *-------*-------*-------*-------*-------*-------*-------*-------*-------*-------*-------* + # Build formatted sequences for speech generation + # *-------*-------*-------*-------*-------*-------*-------*-------*-------*-------*-------* + if not step_failed: + prompt_t2s = [prompt for prompt in T2S_INSTRUCTION] + + all_audio_tokens = [] + for path in audio_paths_t2s: + tokens, err = safe_audio_encode(path, "t2s") + if err is not None: + failure_messages.append(err) + step_failed = True + break + tokens = tokens.to(accelerator.device, non_blocking=True) + tokens_with_offset = tokens + offset + all_audio_tokens.append(tokens_with_offset) + + if not step_failed: + # Chat-style instruction formatting for T2S training + prompt = random.choice(prompt_t2s) + texts_with_prompt = [ + f"<|start_header_id|>user<|end_header_id|>\n{prompt}\n{text}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n" + for text in texts_t2s + ] + + # input_ids_t2s, prompt_masks_t2s, labels_t2s = uni_prompting((texts_with_prompt, all_audio_tokens), 't2s_ip') + input_ids_t2s, prompt_masks_t2s, labels_t2s = uni_prompting((texts_with_prompt, all_audio_tokens), 't2s') + input_ids_t2s, labels_t2s, p_mask_t2s, answer_lengths_t2s = prepare_inputs_and_labels_for_t2s(input_ids_t2s, prompt_masks_t2s, labels_t2s) + + failure_tensor = torch.tensor(1 if step_failed else 0, device=accelerator.device, dtype=torch.int32) + failure_sum = accelerator.reduce(failure_tensor, reduction='sum') + if failure_sum.item() > 0: + if accelerator.is_main_process and failure_messages: + for msg in failure_messages: + logger.warning(f"Skipping global step {global_step} due to preprocessing failure: {msg}") + batch_time_m.reset() + data_time_m.reset() + end = time.time() + continue + + # --------------------------------------------------------------------------------- + # 1. Define padding values + pad_token_id = uni_prompting.text_tokenizer.eos_token_id + + # 2. Find the maximum sequence length in the current batch + max_len = max( + input_ids_vid.shape[1], + input_ids_s2t.shape[1], + input_ids_t2s.shape[1] + ) + + # 3. Pad all tensors to the max_len + input_ids_vid = pad_tensor(input_ids_vid, max_len, pad_token_id) + input_ids_s2t = pad_tensor(input_ids_s2t, max_len, pad_token_id) + input_ids_t2s = pad_tensor(input_ids_t2s, max_len, pad_token_id) + labels_vid = pad_tensor(labels_vid, max_len, -100) + labels_s2t = pad_tensor(labels_s2t, max_len, -100) + labels_t2s = pad_tensor(labels_t2s, max_len, -100) + p_mask_vid = pad_tensor(p_mask_vid, max_len, 1.0) + p_mask_s2t = pad_tensor(p_mask_s2t, max_len, 1.0) + p_mask_t2s = pad_tensor(p_mask_t2s, max_len, 1.0) + answer_lengths_vid = pad_answer_lengths(answer_lengths_vid, max_len) + answer_lengths_s2t = pad_answer_lengths(answer_lengths_s2t, max_len) + answer_lengths_t2s = pad_answer_lengths(answer_lengths_t2s, max_len) + # --------------------------------------------------------------------------------- + + input_ids = torch.cat(( + input_ids_vid, + input_ids_s2t, + input_ids_t2s + ), dim=0) + labels = torch.cat(( + labels_vid, + labels_s2t, + labels_t2s + ), dim=0) + + # w/o texts and images + p_mask_lm = None + p_mask_mmu = None + answer_lengths_mmu = None + t2i_masks = None + + if global_step == 0 and epoch == 0: + logger.info("Input ids: {}".format(input_ids)) + logger.info("Input ids shape: {}".format(input_ids.shape)) + logger.info("Labels: {}".format(labels)) + + # with accelerator.accumulate(model): + logits, loss_t2i, loss_lm, loss_mmu, loss_vid, loss_s2t, loss_t2s = accelerator.unwrap_model(model).forward_process( + # logits, loss_t2i, loss_lm, loss_mmu, loss_vid, loss_s2t, loss_t2s = model.forward_process( + input_ids=input_ids, + labels=labels, + batch_size_t2i=batch_size_t2i, + batch_size_lm=batch_size_lm, + batch_size_mmu=batch_size_mmu, + batch_size_v2t=batch_size_v2t, + batch_size_s2t=batch_size_s2t, + batch_size_t2s=batch_size_t2s, + max_seq_length=config.dataset.preprocessing.max_seq_length, + p_mask_lm=p_mask_lm, + p_mask_mmu=p_mask_mmu, + p_mask_vid=p_mask_vid, + p_mask_s2t=p_mask_s2t, + p_mask_t2s=p_mask_t2s, + answer_lengths_mmu=answer_lengths_mmu, + answer_lengths_vid=answer_lengths_vid, + answer_lengths_s2t=answer_lengths_s2t, + answer_lengths_t2s=answer_lengths_t2s, + t2i_masks=t2i_masks, + t2s_vocab_start=speech_vocab_start, + t2s_codebook_size=audio_codebook_size, + t2s_special_token_ids=t2s_special_token_ids, + ) + + # Gather the losses across all processes for logging (use reduce to avoid shape mismatches) + # avg_loss_t2i = accelerator.reduce(loss_t2i, reduction='mean') + # avg_loss_lm = accelerator.reduce(loss_lm, reduction='mean') + # avg_loss_mmu = accelerator.reduce(loss_mmu, reduction='mean') + + avg_loss_vid = accelerator.reduce(loss_vid, reduction='mean') + avg_loss_s2t = accelerator.reduce(loss_s2t, reduction='mean') + avg_loss_t2s = accelerator.reduce(loss_t2s, reduction='mean') + + # loss = (config.training.t2i_coeff * loss_t2i + + # config.training.lm_coeff * loss_lm + + # config.training.mmu_coeff * loss_mmu + + # config.training.vid_coeff * loss_vid + + # config.training.s2t_coeff * loss_s2t + + # config.training.t2s_coeff * loss_t2s) + + loss = (config.training.v2t_coeff * loss_vid + + config.training.s2t_coeff * loss_s2t + + config.training.t2s_coeff * loss_t2s) + + # HMM~~~~~ + avg_masking_rate = accelerator.reduce(p_mask_t2s.mean(), reduction='mean') + + accelerator.backward(loss) + + if config.training.max_grad_norm is not None and accelerator.sync_gradients: + accelerator.clip_grad_norm_(model.parameters(), config.training.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + + # log gradient norm before zeroing it + if ( + accelerator.sync_gradients + and (global_step + 1) % config.experiment.log_grad_norm_every == 0 + and accelerator.is_main_process + ): + log_grad_norm(model, accelerator, global_step + 1) + + optimizer.zero_grad(set_to_none=True) + + if accelerator.sync_gradients: + batch_time_m.update(time.time() - end) + end = time.time() + + # Log metrics + if (global_step + 1) % config.experiment.log_every == 0: + samples_per_second_per_gpu = ( + config.training.gradient_accumulation_steps * total_batch_size_per_gpu / batch_time_m.val + ) + logs = { + # "step_loss_t2i": avg_loss_t2i.item(), + # "step_loss_mmu": avg_loss_mmu.item(), + # "step_loss_lm": avg_loss_lm.item(), + "step_loss_vid": avg_loss_vid.item(), + "step_loss_s2t": avg_loss_s2t.item(), + "step_loss_t2s": avg_loss_t2s.item(), + "lr": lr_scheduler.get_last_lr()[0], + # "avg_masking_rate": avg_masking_rate.item(), + "samples/sec/gpu": samples_per_second_per_gpu, + "data_time": data_time_m.val, + "batch_time": batch_time_m.val, + } + accelerator.log(logs, step=global_step + 1) + + logger.info( + f"Step: {global_step + 1} " + # f"Loss_t2i: {avg_loss_t2i.item():0.4f} " + # f"Loss_mmu: {avg_loss_mmu.item():0.4f} " + # f"Loss_lm: {avg_loss_lm.item():0.4f} " + f"Loss_vid: {avg_loss_vid.item():0.4f} " + f"Loss_s2t: {avg_loss_s2t.item():0.4f} " + f"Loss_t2s: {avg_loss_t2s.item():0.4f} " + f"Data (t): {data_time_m.val:0.4f}, {samples_per_second_per_gpu:0.2f}/s/gpu " + f"Batch (t): {batch_time_m.val:0.4f} " + f"LR: {lr_scheduler.get_last_lr()[0]:0.6f}" + ) + + # resetting batch / data time meters per log window + batch_time_m.reset() + data_time_m.reset() + + # Save model checkpoint + if (global_step + 1) % config.experiment.save_every == 0: + save_checkpoint(model, config, accelerator, global_step + 1, uni_prompting) + + # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + # ++++++++++++++++++++++ RUN EVALUATION +++++++++++++++++++++++++ + # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + if global_step == 0 or (global_step + 1) % config.experiment.get("eval_every", 5000) == 0: + run_evaluation( + model=accelerator.unwrap_model(model), + vq_model_image=vq_model_image, + vq_model_audio=vq_model_audio, + uni_prompting=uni_prompting, + config=config, + accelerator=accelerator, + global_step=global_step + 1 + ) + # Evaluation function sets model back to train mode internally + # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + + global_step += 1 + + if global_step >= config.training.max_train_steps: + break + + if global_step >= config.training.max_train_steps: + break + + accelerator.wait_for_everyone() + + save_checkpoint(model, config, accelerator, global_step, uni_prompting) + + if accelerator.is_main_process: + model = accelerator.unwrap_model(model) + model.save_pretrained(config.experiment.output_dir, safe_serialization=True) + + accelerator.end_training() + +@torch.no_grad() +def visualize_predictions(*args, **kwargs): + # This function is not called in the main loop but kept for compatibility + pass + +@torch.no_grad() +def generate_images(*args, **kwargs): + # This function is not called in the main loop but kept for compatibility + pass + +@torch.no_grad() +def understanding_images(*args, **kwargs): + # This function is not called in the main loop but kept for compatibility + pass + +def save_checkpoint(model, config, accelerator, global_step, uni_prompting): + output_dir = config.experiment.output_dir + checkpoints_total_limit = config.experiment.get("checkpoints_total_limit", None) + + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if accelerator.is_main_process and checkpoints_total_limit is not None: + checkpoints = os.listdir(output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= checkpoints_total_limit: + num_to_remove = len(checkpoints) - checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = Path(output_dir) / f"checkpoint-{global_step}" + + # retrieve the model on all processes for deepspeed stage 3 to work then save on one process (we are not using stage 3 yet) + # XXX: could also make this conditional on deepspeed + state_dict = accelerator.get_state_dict(model) + if accelerator.is_main_process: + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.save_pretrained( + save_path / "unwrapped_model", + save_function=accelerator.save, + state_dict=state_dict, + safe_serialization=True + ) + json.dump({"global_step": global_step}, (save_path / "metadata.json").open("w+")) + logger.info(f"Saved state to {save_path}") + + # save tokenizer + uni_prompting.text_tokenizer.save_pretrained(save_path/ "unwrapped_model") + + +def log_grad_norm(model, accelerator, global_step): + for name, param in model.named_parameters(): + if param.grad is not None: + grads = param.grad.detach().data + grad_norm = (grads.norm(p=2) / grads.numel()).item() + accelerator.log({"grad_norm/" + name: grad_norm}, step=global_step) + + +if __name__ == "__main__": + main() diff --git a/MMaDA/training/train_omada_stage1-3_plus_t2i_mmu_lm.py b/MMaDA/training/train_omada_stage1-3_plus_t2i_mmu_lm.py new file mode 100644 index 0000000000000000000000000000000000000000..619a3989425530ad9182c259c53317df0e7c6cd2 --- /dev/null +++ b/MMaDA/training/train_omada_stage1-3_plus_t2i_mmu_lm.py @@ -0,0 +1,1742 @@ +# Copyright 2025 AIDAS Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +os.environ["TOKENIZERS_PARALLELISM"] = "true" +import json +import pandas +import logging +import math +import shutil +import time +import cv2 +import glob +import random +from tqdm import tqdm +from pathlib import Path +from typing import Optional, Union +import csv +import numpy as np +from PIL import Image +from omegaconf import OmegaConf +import wandb +import torch +from torch.optim import AdamW +from lightning.pytorch.utilities import CombinedLoader + +from transformers import AutoTokenizer, AutoConfig +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import DistributedType, set_seed +# +++++ I2I-specific Imports +++++ +from datasets import load_dataset +from torch.utils.data import Dataset, DataLoader +from tqdm.auto import tqdm +# ++++++++++++++++++++++++++++++ + +# +++++ Omni-modal-specific Imports +++++ +from models.modeling_emova_speech_tokenizer import EMOVASpeechTokenizer +from datasets import load_dataset +from torch.utils.data import Dataset, DataLoader, DistributedSampler +from tqdm.auto import tqdm +from training.data import SpeechTextDataset, MixedSpeechTextDataset, load_video_mp4, VideoCaptionDataset, S2T_INSTRUCTION, T2S_INSTRUCTION +# import librosa + +from training.data import Text2ImageDataset +from training.utils import get_config, flatten_omega_conf, image_transform +from training.imagenet_dataset import ImageNetDataset +from parquet import RefinedWebDataset, ChatDataset + +from models import MAGVITv2, get_mask_schedule, OMadaModelLM, OMadaConfig +from training.prompting_utils import UniversalPrompting +from models.lr_schedulers import get_scheduler +from models.logging import set_verbosity_info, set_verbosity_error + +from torch.utils.data import DataLoader, Dataset +from torch.utils.data.distributed import DistributedSampler + +# ++++++++ EVALUATION IMPORTS ++++++++ +import re +import editdistance +import soundfile as sf +from functools import partial +from transformers import pipeline +# ++++++++++++++++++++++++++++++++++++ + +SYSTEM_PROMPT_LEN = 28 + +from training.utils import get_config, flatten_omega_conf, mask_or_random_replace_tokens, AverageMeter + +try: + import apex + + is_apex_available = True +except ImportError: + is_apex_available = False + +logger = get_logger(__name__, log_level="INFO") + +def pad_tensor(tensor, length, value): + pad_size = length - tensor.shape[1] + if pad_size <= 0: + return tensor + # Pad on the right side of the sequence (last dimension) + return torch.nn.functional.pad(tensor, (0, pad_size), "constant", value) + +def pad_answer_lengths(ans: torch.Tensor, length: int) -> torch.Tensor: + b, l = ans.shape + if l >= length: + return ans + pad_block = ans[:, :1].expand(b, length - l) + return torch.cat([ans, pad_block], dim=1) + +def resize_vocab(model, config): + logger.info(f"Resizing token embeddings to {config.model.omada.new_vocab_size}") + model.resize_token_embeddings(config.model.omada.new_vocab_size) + +def get_vq_model_class(model_type): + if model_type == "magvitv2": + return MAGVITv2 + elif model_type == "emova": + return EMOVASpeechTokenizer.from_pretrained( + "Emova-ollm/emova_speech_tokenizer_hf" + ) + else: + raise ValueError(f"model_type {model_type} not supported.") + +def collate_fn_audio(batch): + # In this setup, the tokenizer handles batching of audio paths + return { + 'audio_path': [item['audio_path'] for item in batch], + 'text': [item['text'] for item in batch], + } + +def collate_fn_video_caption(batch): + + batch = [item for item in batch if item is not None] + if len(batch) == 0: + return None + + frame_list = [] + input_ids_list = [] + for item in batch: + frame_tensor = torch.stack(item['video'], dim=0) # (T, C, H, W) + frame_list.append(frame_tensor) + input_ids_list.append(item['caption']) + + frames = torch.stack(frame_list, dim=0) # (B, T, C, H, W) + + return { + "video": frames, # torch tensor (B, T, C, H, W) + "captions": input_ids_list # input_ids (B, seq_len) + } + +def s2t_eval_collate_fn(batch, vq_model_audio, tokenizer, uni_prompting, config): + + audio_tokens_batch = [] + offset = len(uni_prompting.text_tokenizer) + int(config.model.omada.codebook_size) + for item in batch: + path = item['audio_path'] + tokens = vq_model_audio.encode(path) + tokens_with_offset = tokens + offset + audio_tokens_batch.append(tokens_with_offset) + + sptids_dict = uni_prompting.sptids_dict + device = audio_tokens_batch[0].device + batched_input_ids = [] + + for audio_tokens in audio_tokens_batch: + task_tensor = sptids_dict['<|s2t|>'].to(device).unsqueeze(0) + soa_tensor = sptids_dict['<|soa|>'].to(device).unsqueeze(0) + eoa_tensor = sptids_dict['<|eoa|>'].to(device).unsqueeze(0) + audio_block = torch.cat([task_tensor, soa_tensor, audio_tokens, eoa_tensor], dim=1) + + prompt_text = random.choice(S2T_INSTRUCTION) + full_prompt_text = f'<|start_header_id|>user<|end_header_id|>\n{prompt_text}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n' + prompt_tensor = tokenizer(full_prompt_text, return_tensors="pt").input_ids.to(device) + + final_sequence = torch.cat([audio_block, prompt_tensor], dim=1) + batched_input_ids.append(final_sequence.squeeze(0)) + + max_len = max(seq.size(0) for seq in batched_input_ids) + pad_token_id = 126093 + + final_batch_input_ids = torch.full( + (len(batched_input_ids), max_len), + pad_token_id, + dtype=torch.long, + device=device + ) + + for i, seq in enumerate(batched_input_ids): + final_batch_input_ids[i, -len(seq):] = seq + + return { + "input_ids": final_batch_input_ids, + "gt_texts": [item['gt_text'] for item in batch], + "sample_ids": [item['sample_id'] for item in batch] + } + +################################################################################################ +# +++++++++++++++++++++++++++++++++++++ EVALUATION HELPERS +++++++++++++++++++++++++++++++++++++ +################################################################################################ + +def normalize_text(text): + """A simple normalizer for WER calculation.""" + text = text.lower() + text = re.sub(r"[^\w\s']", "", text) + return text + +def calculate_wer(predictions, references): + """Calculates the Word Error Rate (WER) between predicted and ground truth texts.""" + predictions = [normalize_text(p) for p in predictions] + references = [normalize_text(r) for r in references] + + total_errors = 0 + total_words = 0 + for pred, ref in zip(predictions, references): + pred_words = pred.split() + ref_words = ref.split() + total_errors += editdistance.eval(pred_words, ref_words) + total_words += len(ref_words) + + wer = total_errors / total_words if total_words > 0 else 0.0 + return wer, total_errors, total_words + +class S2TEvalDataset(Dataset): + def __init__(self, hf_dataset, root_path): + self.hf_dataset = hf_dataset + self.root_path = root_path + + def __len__(self): + return len(self.hf_dataset) + + def __getitem__(self, idx): + example = self.hf_dataset[idx] + sample_id = example['id'] + speaker_id, chapter_id, _ = sample_id.split('-') + audio_path = os.path.join(self.root_path, speaker_id, chapter_id, f"{sample_id}.flac") + + return { + "audio_path": audio_path, + "gt_text": example["text"], + "sample_id": sample_id + } + +# --- T2S Evaluation Dataset --- +class T2SEvalDataset(Dataset): + def __init__(self, hf_dataset): + self.hf_dataset = hf_dataset + def __len__(self): + return len(self.hf_dataset) + def __getitem__(self, idx): + example = self.hf_dataset[idx] + return {"gt_text": example['text'], "sample_id": example['id']} + + +################################################################################################ +# +++++++++++++++++++++++++++++++++++++ S2T EVALUATION LOGIC +++++++++++++++++++++++++++++++++++++ +################################################################################################ +@torch.no_grad() +def evaluate_s2t(model, vq_model_audio, uni_prompting, config, accelerator, global_step): + if not accelerator.is_main_process: + return + logger.info("***** Running S2T Evaluation (WER on Librispeech test-clean) *****") + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.eval() + + # 1. Load Dataset + try: + s2t_eval_dataset_raw = load_dataset("librispeech_asr", "clean", split="test", streaming=False).select(range(32)) + s2t_eval_dataset = S2TEvalDataset(s2t_eval_dataset_raw, root_path = "/home/work/AIDAS/data/audio/LibriSpeech/test-clean") + except Exception as e: + logger.error(f"Failed to load S2T evaluation dataset: {e}") + return + + collate_with_args = partial( + s2t_eval_collate_fn, + vq_model_audio=vq_model_audio, + tokenizer=uni_prompting.text_tokenizer, + uni_prompting=uni_prompting, + config=config + ) + + s2t_eval_dataloader = DataLoader(s2t_eval_dataset, batch_size=config.training.batch_size_s2t, shuffle=False, collate_fn=collate_with_args) + + local_results = [] + + for batch in tqdm(s2t_eval_dataloader, desc="S2T Evaluation"): + input_ids = batch["input_ids"] + gt_texts = batch["gt_texts"] + sample_ids = batch["sample_ids"] + + output_ids = unwrapped_model.mmu_generate(input_ids, max_new_tokens=256, steps=256, block_length=128, remasking='low_confidence') + + decoded_texts = uni_prompting.text_tokenizer.batch_decode(output_ids[:, input_ids.shape[1]:], skip_special_tokens=True) + + eos_token = uni_prompting.text_tokenizer.eos_token + eos_marker = eos_token if eos_token is not None else "" + for i in range(len(decoded_texts)): + full_text = decoded_texts[i] + eos_idx = full_text.find(eos_marker) + cleaned_text = full_text[:eos_idx] if eos_idx != -1 else full_text + cleaned_text = cleaned_text.replace(eos_marker, "").strip() + local_results.append({ + "sample_id": sample_ids[i], + "gt_text": gt_texts[i], + "decoded_text": cleaned_text, + }) + + if not local_results: + logger.warning("S2T evaluation produced no results.") + return + + gt_list = [res["gt_text"] for res in local_results] + pred_list = [res["decoded_text"] for res in local_results] + + wer, errors, words = calculate_wer(pred_list, gt_list) + logger.info(f"S2T Final WER (Librispeech test-clean): {wer:.4f} | Word Errors: {errors} | Total Words: {words}") + + accelerator.log({ + "eval/s2t_wer": wer, + "eval/s2t_word_errors": errors, + "eval/s2t_total_words": words + }, step=global_step) + + samples_table = wandb.Table(columns=["ID", "Ground Truth", "Prediction"]) + for idx, res in enumerate(local_results): + sample_id = res.get("sample_id", idx) + samples_table.add_data(sample_id, res["gt_text"], res["decoded_text"]) + + accelerator.log({"eval/s2t_samples": samples_table}, step=global_step) + +################################################################################################ +# +++++++++++++++++++++++++++++++++++++ T2S EVALUATION LOGIC +++++++++++++++++++++++++++++++++++++ +################################################################################################ +@torch.no_grad() +def evaluate_t2s(model, vq_model_audio, uni_prompting, config, accelerator, global_step): + if not accelerator.is_main_process: + return + logger.info("***** Running T2S Evaluation (WER via Whisper on Librispeech) *****") + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.eval() + + # 1. Load Dataset & Whisper Model + try: + t2s_eval_dataset_raw = load_dataset("librispeech_asr", "clean", split="test").select(range(32)) + whisper_pipe = pipeline("automatic-speech-recognition", model="openai/whisper-large-v3", device=accelerator.device) + os.makedirs(f"{config.experiment.output_dir}/eval_audio", exist_ok=True) + except Exception as e: + logger.error(f"Failed to load T2S dataset or Whisper model: {e}") + return + + output_dir_per_step = os.path.join("/home/work/AIDAS", config.experiment.output_dir, "eval_audio", f"step_{global_step}") + os.makedirs(output_dir_per_step, exist_ok=True) + + t2s_eval_dataset = T2SEvalDataset(t2s_eval_dataset_raw) + t2s_dataloader = DataLoader(t2s_eval_dataset, batch_size=config.training.batch_size_t2s) + + local_results = [] + mask_token_id = unwrapped_model.config.mask_token_id + mask_schedule = get_mask_schedule(config.training.get("mask_schedule", "cosine")) + + # 2. Evaluation Loop + for batch in tqdm(t2s_dataloader, desc="T2S Evaluation"): + gt_texts = batch["gt_text"] + sample_ids = batch["sample_id"] + + # Chat-style instruction formatting for T2S: user prompt + text + prompts = [ + f"<|start_header_id|>user<|end_header_id|>\n{random.choice(T2S_INSTRUCTION)}\n{text}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n" + for text in gt_texts + ] + batch_size = len(prompts) + + # We need a reasonable length for generated audio tokens + speech_token_length = 384 - 1 # -1 for soa token + audio_tokens = torch.ones((batch_size, speech_token_length), dtype=torch.long, device=accelerator.device) * mask_token_id + input_ids, attention_mask = uni_prompting((prompts, audio_tokens), 't2s_gen') + + if config.training.guidance_scale > 0: + uncond_input_ids, uncond_attention_mask = uni_prompting(([''] * batch_size, audio_tokens), 't2s_gen') + else: + uncond_input_ids, uncond_attention_mask = None, None + + output_ids = unwrapped_model.t2s_generate( + input_ids=input_ids, + uncond_input_ids=uncond_input_ids, + attention_mask=attention_mask, + uncond_attention_mask=uncond_attention_mask, + guidance_scale=5.0, + temperature=1.0, + timesteps=50, + noise_schedule=mask_schedule, + noise_type="mask", + seq_len=383, + uni_prompting=uni_prompting, + config=config, + ) + + # Decode and run Whisper + for i in range(batch_size): + gt = gt_texts[i].rsplit("\n", 1)[-1].strip() + + gen_speech_tokens = output_ids[i] + + # Remove padding/eos if necessary, clamp to valid range + # gen_speech_tokens = torch.clamp(gen_speech_tokens, min=0, max= 4096 - 1) + id_list = gen_speech_tokens.cpu().tolist() + + if not id_list: + logger.warning(f"Generated token list is empty for sample {sample_ids[i]}. Skipping.") + continue + + speech_unit_str = " ".join(map(str, id_list)) + speech_unit_for_decode = "".join([f"<|speech_{unit}|>" for unit in speech_unit_str.split(" ")]) + + filename = f"process_{accelerator.process_index}_{sample_ids[i]}.wav" + output_wav_path = os.path.join(output_dir_per_step, filename) + condition = 'gender-female_emotion-neutral_speed-normal_pitch-normal' + + audio_array = vq_model_audio.decode(speech_unit_for_decode, condition=condition, output_wav_file=output_wav_path) + + whisper_result = whisper_pipe(output_wav_path, generate_kwargs={"language": "english"}) + whisper_text = whisper_result.get("text", "") + + local_results.append({ + "sample_id": sample_ids[i], "gt_text": gt, "whisper_text": whisper_text, "audio_path": output_wav_path + }) + + if not local_results: + logger.warning("Skipping T2S evaluation logging because no samples were generated.") + return + + gt_list = [res["gt_text"] for res in local_results] + pred_list = [res["whisper_text"] for res in local_results] + + wer, errors, words = calculate_wer(pred_list, gt_list) + logger.info(f"T2S Final WER (via Whisper): {wer:.4f} | Word Errors: {errors} | Total Words: {words}") + + accelerator.log({ + "eval/t2s_wer": wer, + "eval/t2s_word_errors": errors, + "eval/t2s_total_words": words + }, step=global_step) + + results_table = wandb.Table(columns=["ID", "Ground Truth", "Whisper Transcription", "Generated Audio"]) + for res in local_results[:8]: + audio = wandb.Audio(res["audio_path"], caption=res["whisper_text"]) + results_table.add_data(res["sample_id"], res["gt_text"], res["whisper_text"], audio) + + accelerator.log({"eval/t2s_samples": results_table}, step=global_step) + +@torch.no_grad() +def evaluate_t2s_mmu_like(model, vq_model_audio, uni_prompting, config, accelerator, global_step): + """Text-to-speech evaluation using the MMU-style block refinement decoder.""" + + if not accelerator.is_main_process: + return + + logger.info("***** Running T2S Evaluation (MMU-style decoder) *****") + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.eval() + + try: + t2s_eval_dataset_raw = load_dataset("librispeech_asr", "clean", split="test").select(range(32)) + whisper_pipe = pipeline("automatic-speech-recognition", model="openai/whisper-large-v3", device=accelerator.device) + os.makedirs(f"{config.experiment.output_dir}/eval_audio", exist_ok=True) + except Exception as exc: + logger.error(f"Failed to load T2S dataset or Whisper model for MMU-style eval: {exc}") + return + + output_dir_per_step = os.path.join("/home/work/AIDAS", config.experiment.output_dir, "eval_audio", f"step_{global_step}_mmu") + os.makedirs(output_dir_per_step, exist_ok=True) + + t2s_eval_dataset = T2SEvalDataset(t2s_eval_dataset_raw) + t2s_dataloader = DataLoader(t2s_eval_dataset, batch_size=config.training.batch_size_t2s) + + local_results = [] + mask_token_id = unwrapped_model.config.mask_token_id + + codebook_size = config.model.omada.codebook_size + speech_vocab_size = 4096 + + for batch in tqdm(t2s_dataloader, desc="T2S MMU Eval"): + gt_texts = batch["gt_text"] + sample_ids = batch["sample_id"] + + prompts = [ + f"<|start_header_id|>user<|end_header_id|>\n{random.choice(T2S_INSTRUCTION)}\n{text}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n" + for text in gt_texts + ] + + batch_size = len(prompts) + speech_token_length = 384 - 1 + audio_tokens = torch.ones((batch_size, speech_token_length), dtype=torch.long, device=accelerator.device) * mask_token_id + input_ids, attention_mask = uni_prompting((prompts, audio_tokens), 't2s_gen') + + output_ids = unwrapped_model.t2s_generate_mmu_like( + input_ids=input_ids, + max_new_tokens=speech_token_length, + steps=384 - 1, + block_length=384 - 1, + temperature=1.0, + cfg_scale=3.0, + mask_token_id=mask_token_id, + attention_mask=attention_mask, + uni_prompting=uni_prompting, + codebook_size=codebook_size, + audio_codebook_size=speech_vocab_size, + ) + + for i in range(batch_size): + gt = gt_texts[i].rsplit("\n", 1)[-1].strip() + + gen_speech_tokens = output_ids[i] + if isinstance(gen_speech_tokens, torch.Tensor): + gen_speech_tokens = gen_speech_tokens.detach().cpu() + + token_list = gen_speech_tokens.tolist() + if not token_list: + logger.warning(f"Generated token list is empty for sample {sample_ids[i]} (MMU eval). Skipping.") + continue + + speech_unit_str = " ".join(map(str, token_list)) + speech_unit_for_decode = "".join([f"<|speech_{unit}|>" for unit in speech_unit_str.split(" ")]) + + filename = f"process_{accelerator.process_index}_{sample_ids[i]}_mmu.wav" + output_wav_path = os.path.join(output_dir_per_step, filename) + condition = 'gender-female_emotion-neutral_speed-normal_pitch-normal' + + try: + vq_model_audio.decode(speech_unit_for_decode, condition=condition, output_wav_file=output_wav_path) + except Exception as exc: + logger.error(f"Decoding failed for sample {sample_ids[i]} (MMU eval): {exc}") + continue + + whisper_result = whisper_pipe(output_wav_path, generate_kwargs={"language": "english"}) + whisper_text = whisper_result.get("text", "") + + local_results.append({ + "sample_id": sample_ids[i], + "gt_text": gt, + "whisper_text": whisper_text, + "audio_path": output_wav_path, + }) + + if not local_results: + logger.warning("Skipping T2S MMU-style evaluation because no samples were generated.") + return + + gt_list = [res["gt_text"] for res in local_results] + pred_list = [res["whisper_text"] for res in local_results] + + wer, errors, words = calculate_wer(pred_list, gt_list) + logger.info(f"T2S (MMU-style) Final WER: {wer:.4f} | Word Errors: {errors} | Total Words: {words}") + + accelerator.log({ + "eval/t2s_mmu_like_wer": wer, + "eval/t2s_mmu_like_word_errors": errors, + "eval/t2s_mmu_like_total_words": words, + }, step=global_step) + + results_table = wandb.Table(columns=["ID", "Ground Truth", "Whisper Transcription", "Generated Audio"]) + for res in local_results[:8]: + audio = wandb.Audio(res["audio_path"], caption=res["whisper_text"]) + results_table.add_data(res["sample_id"], res["gt_text"], res["whisper_text"], audio) + + accelerator.log({"eval/t2s_mmu_like_samples": results_table}, step=global_step) + +@torch.no_grad() +def evaluate_t2s_fixed(model, vq_model_audio, uni_prompting, config, accelerator, global_step): + """ + Text-to-Speech (fixed-length) evaluation: + - Input prompt contains SOA + [MASK]*L + EOA (EOA is injected, not predicted) + - The model only fills VQ codes for exactly L positions (no EOA/EOS prediction) + - Generated audio is transcribed by Whisper; we report WER + """ + if not accelerator.is_main_process: + return + logger.info("***** Running T2S (fixed-length) Evaluation *****") + unwrapped = accelerator.unwrap_model(model) + unwrapped.eval() + + # Load eval dataset and Whisper model + try: + ds_raw = load_dataset("librispeech_asr", "clean", split="test").select(range(128)) + whisper_pipe = pipeline( + "automatic-speech-recognition", + model="openai/whisper-large-v3", + device=accelerator.device + ) + os.makedirs(f"{config.experiment.output_dir}/eval_audio", exist_ok=True) + except Exception as e: + logger.error(f"Failed to load dataset or Whisper model: {e}") + return + + # Directory for saving generated audio files of this evaluation step + out_dir = os.path.join( + "/home/work/AIDAS", config.experiment.output_dir, "eval_audio", f"step_{global_step}_fixed" + ) + os.makedirs(out_dir, exist_ok=True) + + eval_ds = T2SEvalDataset(ds_raw) + loader = DataLoader(eval_ds, batch_size=config.training.batch_size_t2s) + + local_results = [] + mask_token_id = unwrapped.config.mask_token_id + mask_schedule = get_mask_schedule(config.training.get("mask_schedule", "cosine")) + + for batch in tqdm(loader, desc="T2S Fixed Evaluation"): + gt_texts = batch["gt_text"] + sample_ids = batch["sample_id"] + + # Chat-style instruction formatting for fixed-length T2S + prompts = [ + f"<|start_header_id|>user<|end_header_id|>\n{random.choice(T2S_INSTRUCTION)}\n{text}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n" + for text in gt_texts + ] + batch_size = len(prompts) + + # We need a reasonable length for generated audio tokens + speech_token_length = 256 - 2 # exclude and + audio_tokens = torch.ones((batch_size, speech_token_length), dtype=torch.long, device=accelerator.device) * mask_token_id + input_ids, attention_mask = uni_prompting((prompts, audio_tokens), 't2s_fixed_gen') + + if config.training.guidance_scale > 0: + uncond_input_ids, uncond_attention_mask = uni_prompting(([''] * batch_size, audio_tokens), 't2s_fixed_gen') + else: + uncond_input_ids, uncond_attention_mask = None, None + + # Core generation call: + # - predict_eoa=False prevents EOA/EOS prediction; only VQ codes are generated + outputs = unwrapped.t2s_fixed_generate( + input_ids=input_ids, + uncond_input_ids=uncond_input_ids, + attention_mask=attention_mask, + uncond_attention_mask=uncond_attention_mask, + guidance_scale=1.5, + temperature=1.0, + timesteps=24, + noise_schedule=mask_schedule, + noise_type="mask", + seq_len=speech_token_length, + uni_prompting=uni_prompting, + config=config, + ) + + # Decode generated VQ codes → waveform via the speech tokenizer, then ASR with Whisper + for i in range(batch_size): + gt = gt_texts[i].rsplit("\n", 1)[-1].strip() + gen_rel = outputs[i] # relative VQ ids in [0..4095] + id_list = gen_rel.tolist() + + if not id_list: + logger.warning(f"[fixed] Empty tokens for {sample_ids[i]}; skipping.") + continue + + # Convert to the speech-unit string format expected by the decoder + unit_str = " ".join(map(str, id_list)) + speech_unit_for_decode = "".join([f"<|speech_{u}|>" for u in unit_str.split(" ")]) + + # Synthesize audio and run Whisper + fname = f"process_{accelerator.process_index}_{sample_ids[i]}_fixed.wav" + wav_path = os.path.join(out_dir, fname) + condition = 'gender-female_emotion-neutral_speed-normal_pitch-normal' + + _ = vq_model_audio.decode( + speech_unit_for_decode, + condition=condition, + output_wav_file=wav_path + ) + asr = whisper_pipe(wav_path, generate_kwargs={"language": "english"}) + whisper_text = asr.get("text", "") + + local_results.append({ + "sample_id": sample_ids[i], + "gt_text": gt, + "whisper_text": whisper_text, + "audio_path": wav_path + }) + + if not local_results: + logger.warning("Skipping T2S fixed evaluation logging because no samples were generated.") + return + + gt_list = [r["gt_text"] for r in local_results] + pred_list = [r["whisper_text"] for r in local_results] + wer, errors, words = calculate_wer(pred_list, gt_list) + logger.info(f"T2S Fixed WER: {wer:.4f} | Errors: {errors} | Words: {words}") + + accelerator.log({ + "eval/t2s_fixed_wer": wer, + "eval/t2s_fixed_errors": errors, + "eval/t2s_fixed_words": words + }, step=global_step) + + table = wandb.Table(columns=["ID", "GT", "ASR", "Audio"]) + for r in local_results[:8]: + table.add_data( + r["sample_id"], + r["gt_text"], + r["whisper_text"], + wandb.Audio(r["audio_path"], caption=r["whisper_text"]) + ) + accelerator.log({"eval/t2s_fixed_samples": table}, step=global_step) + +################################################################################################ +# +++++++++++++++++++++++++++++++++++++ V2T EVALUATION LOGIC +++++++++++++++++++++++++++++++++++++ +################################################################################################ +@torch.no_grad() +def evaluate_v2t(model, vq_model_image, uni_prompting, config, accelerator, global_step): + # This is a qualitative evaluation, so it only runs on the main process. + if not accelerator.is_main_process: + return + + logger.info("***** Running V2T Qualitative Evaluation *****") + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.eval() + + video_root = "/home/work/AIDAS/video/demo" + if not video_root or not os.path.exists(video_root): + logger.warning(f"V2T eval root '{video_root}' not found. Skipping V2T evaluation.") + return + + file_list = [f for f in os.listdir(video_root) if f.lower().endswith('.mp4')] + if not file_list: + logger.warning(f"No .mp4 files found in '{video_root}'. Skipping V2T evaluation.") + return + + question = "Please provide a detailed description of the video." + results_table = wandb.Table(columns=["Video ID", "Question", "Generated Caption"]) + + for file_name in tqdm(file_list[:], desc="V2T Evaluation", disable=not accelerator.is_main_process): + video_path = os.path.join(video_root, file_name) + + # 1. Load and process video + cap = cv2.VideoCapture(video_path) + total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + indices = np.linspace(0, total_frames - 1, 8, dtype=int) + frames = [] + for i in range(total_frames): + ret, frame = cap.read() + if i in indices: + if not ret: continue + frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + pil_img = Image.fromarray(frame) + frames.append(image_transform(pil_img, resolution=config.dataset.preprocessing.resolution)) + cap.release() + + if len(frames) < 8: continue + + video_tensor = torch.stack(frames).to(accelerator.device) + video_tokens = vq_model_image.get_code(video_tensor) + len(uni_prompting.text_tokenizer) + video_tokens = video_tokens.view(1, -1) # Flatten tokens + + sptids = uni_prompting.sptids_dict + device = unwrapped_model.device + + prompt_text = f'<|start_header_id|>user<|end_header_id|>\n{question}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n' + prompt_tensor = uni_prompting.text_tokenizer(prompt_text, return_tensors="pt").input_ids.to(device) + + input_ids = torch.cat([ + sptids['<|v2t|>'].to(device).unsqueeze(0), + sptids['<|soi|>'].to(device).unsqueeze(0), + video_tokens, + sptids['<|eoi|>'].to(device).unsqueeze(0), + sptids['<|sot|>'].to(device).unsqueeze(0), + prompt_tensor + ], dim=1).long() + + output_ids = unwrapped_model.mmu_generate(input_ids, max_new_tokens=256, steps=256, block_length=128) + text = uni_prompting.text_tokenizer.batch_decode(output_ids[:, input_ids.shape[1]:], skip_special_tokens=True)[0] + print(text) + # 3. Log result + results_table.add_data(file_name, question, text) + + # except Exception as e: + # logger.error(f"Error processing video {file_name}: {e}") + + accelerator.log({"eval/v2t_qualitative_samples": results_table}, step=global_step) + + +################################################################################################ +# +++++++++++++++++++++++++++++++++++++ MAIN EVALUATION ORCHESTRATOR +++++++++++++++++++++++++++++ +################################################################################################ + +def run_evaluation(model, vq_model_image, vq_model_audio, uni_prompting, config, accelerator, global_step): + """ + Orchestrates the S2T, T2S, and V2T evaluations. + """ + if accelerator.is_main_process: + logger.info(f"--- Starting evaluation at step {global_step} ---") + model.eval() + + if accelerator.is_main_process: + evaluate_s2t(model, vq_model_audio, uni_prompting, config, accelerator, global_step) + # evaluate_t2s(model, vq_model_audio, uni_prompting, config, accelerator, global_step) + evaluate_t2s_mmu_like(model, vq_model_audio, uni_prompting, config, accelerator, global_step) + # evaluate_t2s_fixed(model, vq_model_audio, uni_prompting, config, accelerator, global_step) + evaluate_v2t(model, vq_model_image, uni_prompting, config, accelerator, global_step) + + accelerator.wait_for_everyone() + if accelerator.is_main_process: + logger.info(f"--- Finished evaluation at step {global_step}. Returning to training. ---") + model.train() + + +def main(): + ######################### + # SETUP Accelerator # + ######################### + config = get_config() + + # Enable TF32 on Ampere GPUs + if config.training.enable_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.benchmark = True + torch.backends.cudnn.deterministic = False + + config.experiment.logging_dir = str(Path(config.experiment.output_dir) / "logs") + accelerator = Accelerator( + gradient_accumulation_steps=config.training.gradient_accumulation_steps, + mixed_precision=config.training.mixed_precision, + log_with="wandb", + project_dir=config.experiment.logging_dir, + split_batches=True, + ) + + total_batch_size_per_gpu = (config.training.batch_size_t2i + + config.training.batch_size_lm + + config.training.batch_size_mmu + + config.training.batch_size_v2t + + config.training.batch_size_s2t + + config.training.batch_size_t2s) + total_batch_size = ( + (config.training.batch_size_t2i + + config.training.batch_size_lm + + config.training.batch_size_mmu + + config.training.batch_size_v2t + + config.training.batch_size_s2t + + config.training.batch_size_t2s) * accelerator.num_processes * config.training.gradient_accumulation_steps + ) + + if accelerator.distributed_type == DistributedType.DEEPSPEED: + accelerator.state.deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = ( + total_batch_size_per_gpu + ) + + ##################################### + # SETUP LOGGING, SEED and CONFIG # + ##################################### + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + set_verbosity_info() + else: + set_verbosity_error() + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + resume_wandb_run = config.wandb.resume + run_id = config.wandb.get("run_id", None) + if run_id is None: + resume_wandb_run = False + run_id = wandb.util.generate_id() + config.wandb.run_id = run_id + + wandb_init_kwargs = dict( + name=config.experiment.name, + id=run_id, + resume=resume_wandb_run, + entity=config.wandb.get("entity", None), + config_exclude_keys=[], + dir = config.experiment.logging_dir, + ) + wandb_config = {k: v for k, v in flatten_omega_conf(config, resolve=True)} + wandb_config.pop("experiment.resume_from_checkpoint") + + accelerator.init_trackers( + config.experiment.project, + config=wandb_config, + init_kwargs={"wandb": wandb_init_kwargs}, + ) + + if accelerator.is_main_process: + os.makedirs(config.experiment.output_dir, exist_ok=True) + config_path = Path(config.experiment.output_dir) / "config.yaml" + logging.info(f"Saving config to {config_path}") + OmegaConf.save(config, config_path) + + # If passed along, set the training seed now. + if config.training.seed is not None: + set_seed(config.training.seed) + + ######################### + # MODELS and OPTIMIZER # + ######################### + logger.info("Loading models and optimizer") + + tokenizer = AutoTokenizer.from_pretrained(config.model.omada.tokenizer_path, padding_side="left") + + uni_prompting = UniversalPrompting(tokenizer, max_text_len=config.dataset.preprocessing.max_seq_length, max_audio_len=config.dataset.preprocessing.max_aud_length, + special_tokens=( + "<|soi|>", "<|eoi|>", "<|sov|>", "<|eov|>", "<|t2i|>", + "<|mmu|>", "<|t2v|>", "<|v2v|>", "<|lvg|>", + # Omada Special Tokens + "<|v2t|>", "<|s2t|>", "<|t2s|>", "<|soa|>", "<|eoa|>", + ), + ignore_id=-100, cond_dropout_prob=config.training.cond_dropout_prob, use_reserved_token=True) + + print('special tokens : \n', uni_prompting.sptids_dict) + + speech_vocab_start = len(uni_prompting.text_tokenizer) + int(config.model.omada.codebook_size) + audio_codebook_size = max(int(config.model.omada.new_vocab_size) - speech_vocab_start, 0) + t2s_special_token_ids = { + "eoa": int(uni_prompting.sptids_dict['<|eoa|>'][0].item()), + "eos": int(uni_prompting.text_tokenizer.eos_token_id), + } + + # VQ model for processing image into discrete tokens + vq_model_image = get_vq_model_class(config.model.vq_model_image.type) + if config.model.vq_model_image.get("pretrained_model_path", None): + vq_model_image = vq_model_image().to(accelerator.device) + state_dict = torch.load(config.model.vq_model_image.pretrained_model_path)['model'] + vq_model_image.load_state_dict(state_dict) + else: + vq_model_image = vq_model_image.from_pretrained(config.model.vq_model_image.vq_model_name).to(accelerator.device) + + vq_model_audio = get_vq_model_class(config.model.vq_model_audio.type) + vq_model_audio = vq_model_audio.from_pretrained(config.model.vq_model_audio.vq_model_name).to(accelerator.device) + + vq_model_image.eval() + vq_model_image.requires_grad_(False) + + vq_model_audio.eval() + vq_model_audio.requires_grad_(False) + + model = OMadaModelLM.from_pretrained(config.model.omada.pretrained_model_path, torch_dtype=torch.bfloat16).to(accelerator.device) + + # Resize Vocab size for Audio Modality + unwrapped_model = accelerator.unwrap_model(model) + original_vocab_size = unwrapped_model.get_input_embeddings().weight.shape[0] + logger.info("="*50) + logger.info(f"Calling resize_vocab...") + logger.info(f"Vocab size BEFORE resizing: {original_vocab_size}") + + resize_vocab(unwrapped_model, config) + + resized_vocab_size = unwrapped_model.get_input_embeddings().weight.shape[0] + logger.info(f"Vocab size AFTER resizing: {resized_vocab_size}") + logger.info(f"Config 'new_vocab_size': {config.model.omada.new_vocab_size}") + + if resized_vocab_size == config.model.omada.new_vocab_size: + logger.info("āœ… Vocab resize successful!") + else: + logger.info("āŒ Vocab resize FAILED or did not match config!") + logger.info("="*50) + mask_id = model.config.mask_token_id + + ################################## + # Optimizer and LR scheduler # + ################################# + optimizer_config = config.optimizer.params + + # no decay on bias and layernorm and embedding + no_decay = ["bias", "layer_norm.weight", "mlm_ln.weight", "embeddings.weight"] + optimizer_grouped_parameters = [ + { + "params": [p for n, p in model.named_parameters() if + p.requires_grad and not any(nd in n for nd in no_decay)], + "weight_decay": optimizer_config.weight_decay, + }, + { + "params": [p for n, p in model.named_parameters() if + p.requires_grad and any(nd in n for nd in no_decay)], + "weight_decay": 0.0, + }, + ] + + optimizer_type = config.optimizer.name + if optimizer_type == "adamw": + optimizer = AdamW( + optimizer_grouped_parameters, + lr=optimizer_config.learning_rate, + betas=(optimizer_config.beta1, optimizer_config.beta2), + weight_decay=optimizer_config.weight_decay, + eps=optimizer_config.epsilon, + ) + else: + raise ValueError(f"Optimizer {optimizer_type} not supported") + + # Create mask scheduler + if config.get("mask_schedule", None) is not None: + schedule = config.mask_schedule.schedule + args = config.mask_schedule.get("params", {}) + mask_schedule = get_mask_schedule(schedule, **args) + else: + mask_schedule = get_mask_schedule(config.training.get("mask_schedule", "cosine")) + + ################################## + # DATALOADER # + ################################# + logger.info("Creating dataloaders and lr_scheduler") + + total_batch_size = ( + (config.training.batch_size_t2s + config.training.batch_size_s2t +config.training.batch_size_v2t) * accelerator.num_processes * config.training.gradient_accumulation_steps + ) + preproc_config = config.dataset.preprocessing + dataset_config = config.dataset.params + + # Video Dataset + video_captioning_dataset = VideoCaptionDataset( + transform=image_transform, + tokenizer=uni_prompting.text_tokenizer, + max_seq_length=preproc_config.max_seq_length, + resolution=preproc_config.resolution, + sample_method="uniform", + num_frames=8, + ) + + sampler_v2t = DistributedSampler( + video_captioning_dataset, + shuffle=True, # Should be true for training + drop_last=True + ) + + train_dataloader_v2t = DataLoader( + video_captioning_dataset, + batch_size=config.training.batch_size_v2t, + num_workers=dataset_config.num_workers, + collate_fn=collate_fn_video_caption, + sampler = sampler_v2t, + drop_last=True, + ) + + # Speech Dataset + dataset_sm = MixedSpeechTextDataset(config.dataset.params.audio_data) + + logger.info(f"Dataset Prepared.") + + # Use distinct DistributedSamplers for each speech dataloader to avoid iterator interference + if accelerator.num_processes > 1: + sampler_s2t = DistributedSampler( + dataset_sm, + num_replicas=accelerator.num_processes, + rank=accelerator.process_index, + shuffle=True, + drop_last=True, + ) + sampler_t2s = DistributedSampler( + dataset_sm, + num_replicas=accelerator.num_processes, + rank=accelerator.process_index, + shuffle=True, + drop_last=True, + ) + else: + sampler_s2t = None + sampler_t2s = None + + train_dataloader_s2t = DataLoader( + dataset_sm, + batch_size=config.training.batch_size_s2t, + shuffle=False, + sampler=sampler_s2t, + collate_fn=collate_fn_audio, + num_workers=config.dataset.params.num_workers, + drop_last=True, + ) + train_dataloader_t2s = DataLoader( + dataset_sm, + batch_size=config.training.batch_size_t2s, + shuffle=False, + sampler=sampler_t2s, + collate_fn=collate_fn_audio, + num_workers=config.dataset.params.num_workers, + drop_last=True, + ) + + # Combine these dataloaders into a single iterable model + iterables = { + "v2t_flow": train_dataloader_v2t, + "t2s_flow": train_dataloader_t2s, + "s2t_flow": train_dataloader_s2t, + } + + combined_dataloader = CombinedLoader(iterables, mode=config.dataset.combined_loader_mode) + + # s2t + total_batch_size_s2t = config.training.batch_size_s2t * accelerator.num_processes * config.training.gradient_accumulation_steps + num_update_steps_per_epoch_s2t = math.ceil(len(dataset_sm) / total_batch_size_s2t) + + # t2s + total_batch_size_t2s = config.training.batch_size_t2s * accelerator.num_processes * config.training.gradient_accumulation_steps + num_update_steps_per_epoch_t2s = math.ceil(len(dataset_sm) / total_batch_size_t2s) + + # v2t + total_batch_size_v2t = (config.training.batch_size_v2t * accelerator.num_processes * config.training.gradient_accumulation_steps) + num_update_steps_per_epoch_v2t = math.ceil(len(video_captioning_dataset) / total_batch_size_v2t) + + + # Calculate num_train_epochs + num_update_steps_per_epoch = max(num_update_steps_per_epoch_s2t, num_update_steps_per_epoch_t2s, num_update_steps_per_epoch_v2t) + num_train_epochs = math.ceil(config.training.max_train_steps / num_update_steps_per_epoch) if num_update_steps_per_epoch > 0 else 1 + + logger.info(f"len of speech: {len(dataset_sm)}") + logger.info(f"len of video: {len(video_captioning_dataset)}") + logger.info(f"Train stpes: {config.training.max_train_steps}") + logger.info(f"Num train epochs: {num_train_epochs}") + + ################################## + # MODEL RESUME # + ################################# + global_step = 0 + first_epoch = 0 + start_step = 0 + + if config.experiment.resume_from_checkpoint: + dirs = os.listdir(config.experiment.output_dir) + logger.info(f"dirs: {dirs}") + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + logger.info(f"path: {path}") + if path is not None: + path = os.path.join(config.experiment.output_dir, path) + logger.info(f"Resuming from checkpoint: {path}") + global_step = start_step = int(os.path.basename(path).split("-")[1]) + first_epoch = global_step // num_update_steps_per_epoch + if os.path.exists(f'{path}/unwrapped_model/pytorch_model.bin'): + state_dict = torch.load(f'{path}/unwrapped_model/pytorch_model.bin', map_location="cpu") + model.load_state_dict(state_dict, strict=True) + del state_dict + elif os.path.exists(f'{path}/unwrapped_model/pytorch_model.bin.index.json'): + from safetensors.torch import load_file + from transformers.modeling_utils import load_sharded_checkpoint + load_sharded_checkpoint(model, f'{path}/unwrapped_model/') + # if safetensors sharded checkpoint exists + elif os.path.exists(f'{path}/unwrapped_model/model.safetensors.index.json'): + from transformers.modeling_utils import load_sharded_checkpoint + load_sharded_checkpoint( + model, + f'{path}/unwrapped_model/', + ) + else: + raise FileNotFoundError(f"Checkpoint {path}/unwrapped_model/pytorch_model.bin or safetensors not found") + else: + logger.info("Not resuming from checkpoint") + + ################################## + # Prepare accelerator # + ################################# + logger.info("Preparing model, optimizer and dataloaders") + + lr_scheduler = get_scheduler( + config.lr_scheduler.scheduler, + optimizer=optimizer, + num_training_steps=config.training.max_train_steps, + num_warmup_steps=config.lr_scheduler.params.warmup_steps, + min_lr_scale=config.lr_scheduler.params.min_lr_scale + ) + + # model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler) + model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler) + + lr_scheduler = get_scheduler( + config.lr_scheduler.scheduler, + optimizer=optimizer, + num_training_steps=config.training.max_train_steps, + num_warmup_steps=config.lr_scheduler.params.warmup_steps, + min_lr_scale=config.lr_scheduler.params.min_lr_scale + ) + + vq_model_image.to(device=accelerator.device) + vq_model_audio.to(device=accelerator.device) + + mask_dtype = model.get_input_embeddings().weight.dtype + + def _log_and_flag_failure(message: str, exc: Exception = None): + """Log preprocessing failures on both logger and accelerator console.""" + if exc is not None: + logger.exception(message) + else: + logger.error(message) + accelerator.print(message) + + def safe_audio_encode(audio_path: str, flow_name: str): + try: + tokens = vq_model_audio.encode(audio_path) + return tokens, None + except Exception as exc: + msg = ( + f"[Rank {accelerator.process_index}] {flow_name} audio encode failed " + f"for '{audio_path}': {exc}" + ) + _log_and_flag_failure(msg, exc) + return None, msg + + def safe_video_get_code(video_tensor_sample: torch.Tensor, sample_index: int): + try: + video_token = vq_model_image.get_code(video_tensor_sample) + return video_token, None + except Exception as exc: + msg = ( + f"[Rank {accelerator.process_index}] v2t video encode failed " + f"for sample index {sample_index}: {exc}" + ) + _log_and_flag_failure(msg, exc) + return None, msg + + ################################## + # Training # + ################################# + logger.info("***** Running training *****") + logger.info(f" Num training steps = {config.training.max_train_steps}") + logger.info(f" Instantaneous batch size per device = {total_batch_size_per_gpu}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {config.training.gradient_accumulation_steps}") + + + @torch.no_grad() + def prepare_inputs_and_labels( + pixel_values_or_image_ids: Union[torch.FloatTensor, torch.LongTensor], + texts: Union[str, str], + min_masking_rate: float = 0.0, + is_train: bool = True, + seed: int = None + ): + + image_tokens = vq_model_image.get_code(pixel_values_or_image_ids) + image_tokens = image_tokens + len(uni_prompting.text_tokenizer) + # create MLM mask and labels + input_ids, labels, loss_weight, mask_prob = mask_or_random_replace_tokens( + image_tokens, + mask_id, + config, + mask_schedule=mask_schedule, + is_train=is_train, + ) + input_ids, masks, labels = uni_prompting((texts, input_ids, labels), 't2i') + return input_ids, labels, mask_prob, image_tokens, masks + + @torch.no_grad() + def prepare_inputs_and_labels_for_text( + texts: Union[str, str], max_seq_len, eps=1e-3 + ): + # create MLM mask and labels + + input_ids_lm, prompt_mask, labels_lm = uni_prompting((texts_lm, max_seq_len), 'lm') + b, l = input_ids_lm.shape + t = torch.rand(b, device=input_ids_lm.device) + p_mask = (1 - eps) * t + eps + p_mask = p_mask[:, None].repeat(1, l) + + masked_indices = torch.rand((b, l), device=input_ids_lm.device) < p_mask + # 126336 is used for [MASK] token + noisy_batch = torch.where(masked_indices, mask_id, input_ids_lm) + masked_indices = noisy_batch == mask_id + + return noisy_batch, labels_lm, p_mask + + # Video also uses this. + @torch.no_grad() + def prepare_inputs_and_labels_for_mmu( + input_ids_mmu, prompt_masks, labels_mmu, eps=1e-3 + ): + b, l = input_ids_mmu.shape + t = torch.rand(b, device=input_ids_mmu.device) + p_mask = (1 - eps) * t + eps + p_mask = p_mask[:, None].repeat(1, l) + + masked_indices = torch.rand((b, l), device=input_ids_mmu.device) < p_mask + # 126336 is used for [MASK] token + noisy_batch = torch.where(masked_indices, mask_id, input_ids_mmu) + masked_indices = noisy_batch == mask_id + noisy_batch[prompt_masks.bool()] = input_ids_mmu[prompt_masks.bool()] + masked_indices = noisy_batch == mask_id + + prompt_masks = prompt_masks.to(torch.int64) + answer_lengths = torch.sum((1 - prompt_masks), dim=-1, keepdim=True) + answer_lengths = answer_lengths.repeat(1, noisy_batch.shape[1]) + + return noisy_batch, labels_mmu, p_mask, answer_lengths + + @torch.no_grad() + def prepare_inputs_and_labels_for_t2s( + input_ids_t2s, prompt_masks, labels_t2s, eps=1e-3 + ): + b, l = input_ids_t2s.shape + t = torch.rand(b, device=input_ids_t2s.device) + p_mask = (1 - eps) * t + eps + p_mask = p_mask[:, None].repeat(1, l) + + masked_indices = torch.rand((b, l), device=input_ids_t2s.device) < p_mask + noisy_batch = torch.where(masked_indices, mask_id, input_ids_t2s) + masked_indices = noisy_batch == mask_id + + noisy_batch[prompt_masks.bool()] = input_ids_t2s[prompt_masks.bool()] + masked_indices = noisy_batch == mask_id + + prompt_masks = prompt_masks.to(torch.int64) + answer_lengths = torch.sum((1 - prompt_masks), dim=-1, keepdim=True) + answer_lengths = answer_lengths.repeat(1, noisy_batch.shape[1]) + + return noisy_batch, labels_t2s, p_mask, answer_lengths + + + @torch.no_grad() + def prepare_inputs_and_labels_for_s2t( + input_ids_mmu, prompt_masks, labels_mmu, eps=1e-3 + ): + b, l = input_ids_mmu.shape + t = torch.rand(b, device=input_ids_mmu.device) + p_mask = (1 - eps) * t + eps + p_mask = p_mask[:, None].repeat(1, l) + + masked_indices = torch.rand((b, l), device=input_ids_mmu.device) < p_mask + # 126336 is used for [MASK] token + noisy_batch = torch.where(masked_indices, mask_id, input_ids_mmu) + masked_indices = noisy_batch == mask_id + noisy_batch[prompt_masks.bool()] = input_ids_mmu[prompt_masks.bool()] + masked_indices = noisy_batch == mask_id + + prompt_masks = prompt_masks.to(torch.int64) + answer_lengths = torch.sum((1 - prompt_masks), dim=-1, keepdim=True) + answer_lengths = answer_lengths.repeat(1, noisy_batch.shape[1]) + + return noisy_batch, labels_mmu, p_mask, answer_lengths + + batch_time_m = AverageMeter() + data_time_m = AverageMeter() + end = time.time() + + for epoch in tqdm(range(first_epoch, num_train_epochs), desc="Epochs", disable=not accelerator.is_main_process, position=0): + # Ensure all samplers reshuffle in a rank-consistent way each epoch + try: + if isinstance(sampler_v2t, DistributedSampler): + sampler_v2t.set_epoch(epoch) + if accelerator.num_processes > 1: + if sampler_s2t is not None: + sampler_s2t.set_epoch(epoch) + if sampler_t2s is not None: + sampler_t2s.set_epoch(epoch) + except Exception: + pass + model.train() + for batch, batch_idx, dataloader_idx in combined_dataloader: + batch_size_t2i = 0 + batch_size_lm = 0 + batch_size_mmu = 0 + + # Synchronize skip decision across all ranks to avoid collective mismatches + local_skip = 1 if (batch is None or batch.get("v2t_flow") is None) else 0 + try: + skip_tensor = torch.tensor(local_skip, device=accelerator.device, dtype=torch.int32) + skip_sum = accelerator.reduce(skip_tensor, reduction='sum') + should_skip = skip_sum.item() > 0 + except Exception: + # Fallback if reduce isn't available for any reason + should_skip = local_skip == 1 + + if should_skip: + if accelerator.is_main_process and local_skip: + logger.warning(f"Skipping step {global_step} (batch is None or v2t_flow missing) [synced]") + continue + + batch_size_v2t = len(batch["v2t_flow"]["video"]) + batch_size_t2s = len(batch["t2s_flow"]["audio_path"]) + batch_size_s2t = len(batch["s2t_flow"]["audio_path"]) + + logger.info(f"batch_size_v2t: {batch_size_v2t}, batch_size_t2s: {batch_size_t2s}, batch_size_s2t: {batch_size_s2t}" ) + + # print(f"Rank {accelerator.process_index} loading data...") + # print(batch["s2t_flow"]["audio_path"]) + # print(batch["v2t_flow"]['captions']) + + audio_paths_s2t, texts_s2t = batch["s2t_flow"]["audio_path"], batch["s2t_flow"]["text"] + audio_paths_t2s, texts_t2s = batch["t2s_flow"]["audio_path"], batch["t2s_flow"]["text"] + offset = speech_vocab_start + video_tensor, texts_vid = batch["v2t_flow"]["video"], batch["v2t_flow"]["captions"] + + data_time_m.update(time.time() - end) + + failure_messages = [] + step_failed = False + + # *-------*-------*-------*-------*-------*-------*-------*-------*-------*-------*-------* + # Build formatted sequences for video understanding + # *-------*-------*-------*-------*-------*-------*-------*-------*-------*-------*-------* + video_tensor = video_tensor.to(accelerator.device, non_blocking=True) + video_token_list = [] + for vid_idx, video in enumerate(video_tensor): # each video is (T, C, H, W) + tokens, err = safe_video_get_code(video, vid_idx) + if err is not None: + failure_messages.append(err) + step_failed = True + break + video_token = tokens + len(uni_prompting.text_tokenizer) # add offset for video tokens + video_token = video_token.view(-1) # flatten to (T*D) + video_token_list.append(video_token) + + if not step_failed: + video_tokens = torch.stack(video_token_list, dim=0) # (B, T*D) + input_ids_vid, prompt_masks_vid, labels_vid = uni_prompting((video_tokens, texts_vid), 'v2t') + # Keep trailing EOS tokens so v2t learns to emit explicit padding. + + ( + input_ids_vid, + labels_vid, + p_mask_vid, + answer_lengths_vid + ) = prepare_inputs_and_labels_for_mmu(input_ids_vid, prompt_masks_vid, labels_vid) + + input_ids_vid = input_ids_vid.to(accelerator.device, non_blocking=True) + + # *-------*-------*-------*-------*-------*-------*-------*-------*-------*-------*-------* + # Build formatted sequences for speech understanding + # *-------*-------*-------*-------*-------*-------*-------*-------*-------*-------*-------* + if not step_failed: + prompt_s2t = ['<|start_header_id|>user<|end_header_id|>\n' + prompt + '<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n' for prompt in S2T_INSTRUCTION] + + all_audio_tokens = [] + for path in audio_paths_s2t: + tokens, err = safe_audio_encode(path, "s2t") + if err is not None: + failure_messages.append(err) + step_failed = True + break + tokens = tokens.to(accelerator.device, non_blocking=True) + tokens_with_offset = tokens + offset + all_audio_tokens.append(tokens_with_offset) + + if not step_failed: + prompt = random.choice(prompt_s2t) + texts_with_prompt = [f"{prompt}{text}" for text in texts_s2t] + + input_ids_s2t, prompt_masks_s2t, labels_s2t = uni_prompting((all_audio_tokens, texts_with_prompt), 's2t') + # Preserve trailing EOS tokens in s2t targets for explicit prediction. + input_ids_s2t, labels_s2t, p_mask_s2t, answer_lengths_s2t = prepare_inputs_and_labels_for_s2t(input_ids_s2t, prompt_masks_s2t, labels_s2t) + + # *-------*-------*-------*-------*-------*-------*-------*-------*-------*-------*-------* + # Build formatted sequences for speech generation + # *-------*-------*-------*-------*-------*-------*-------*-------*-------*-------*-------* + if not step_failed: + prompt_t2s = [prompt for prompt in T2S_INSTRUCTION] + + all_audio_tokens = [] + for path in audio_paths_t2s: + tokens, err = safe_audio_encode(path, "t2s") + if err is not None: + failure_messages.append(err) + step_failed = True + break + tokens = tokens.to(accelerator.device, non_blocking=True) + tokens_with_offset = tokens + offset + all_audio_tokens.append(tokens_with_offset) + + if not step_failed: + # Chat-style instruction formatting for T2S training + prompt = random.choice(prompt_t2s) + texts_with_prompt = [ + f"<|start_header_id|>user<|end_header_id|>\n{prompt}\n{text}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n" + for text in texts_t2s + ] + + # input_ids_t2s, prompt_masks_t2s, labels_t2s = uni_prompting((texts_with_prompt, all_audio_tokens), 't2s_ip') + input_ids_t2s, prompt_masks_t2s, labels_t2s = uni_prompting((texts_with_prompt, all_audio_tokens), 't2s') + input_ids_t2s, labels_t2s, p_mask_t2s, answer_lengths_t2s = prepare_inputs_and_labels_for_t2s(input_ids_t2s, prompt_masks_t2s, labels_t2s) + + failure_tensor = torch.tensor(1 if step_failed else 0, device=accelerator.device, dtype=torch.int32) + failure_sum = accelerator.reduce(failure_tensor, reduction='sum') + if failure_sum.item() > 0: + if accelerator.is_main_process and failure_messages: + for msg in failure_messages: + logger.warning(f"Skipping global step {global_step} due to preprocessing failure: {msg}") + batch_time_m.reset() + data_time_m.reset() + end = time.time() + continue + + # --------------------------------------------------------------------------------- + # 1. Define padding values + pad_token_id = uni_prompting.text_tokenizer.eos_token_id + + # 2. Find the maximum sequence length in the current batch + max_len = max( + input_ids_vid.shape[1], + input_ids_s2t.shape[1], + input_ids_t2s.shape[1] + ) + + # 3. Pad all tensors to the max_len + input_ids_vid = pad_tensor(input_ids_vid, max_len, pad_token_id) + input_ids_s2t = pad_tensor(input_ids_s2t, max_len, pad_token_id) + input_ids_t2s = pad_tensor(input_ids_t2s, max_len, pad_token_id) + labels_vid = pad_tensor(labels_vid, max_len, -100) + labels_s2t = pad_tensor(labels_s2t, max_len, -100) + labels_t2s = pad_tensor(labels_t2s, max_len, -100) + p_mask_vid = pad_tensor(p_mask_vid, max_len, 1.0) + p_mask_s2t = pad_tensor(p_mask_s2t, max_len, 1.0) + p_mask_t2s = pad_tensor(p_mask_t2s, max_len, 1.0) + answer_lengths_vid = pad_answer_lengths(answer_lengths_vid, max_len) + answer_lengths_s2t = pad_answer_lengths(answer_lengths_s2t, max_len) + answer_lengths_t2s = pad_answer_lengths(answer_lengths_t2s, max_len) + # --------------------------------------------------------------------------------- + + input_ids = torch.cat(( + input_ids_vid, + input_ids_s2t, + input_ids_t2s + ), dim=0) + labels = torch.cat(( + labels_vid, + labels_s2t, + labels_t2s + ), dim=0) + + # w/o texts and images + p_mask_lm = None + p_mask_mmu = None + answer_lengths_mmu = None + t2i_masks = None + + if global_step == 0 and epoch == 0: + logger.info("Input ids: {}".format(input_ids)) + logger.info("Input ids shape: {}".format(input_ids.shape)) + logger.info("Labels: {}".format(labels)) + + # with accelerator.accumulate(model): + logits, loss_t2i, loss_lm, loss_mmu, loss_vid, loss_s2t, loss_t2s = accelerator.unwrap_model(model).forward_process( + # logits, loss_t2i, loss_lm, loss_mmu, loss_vid, loss_s2t, loss_t2s = model.forward_process( + input_ids=input_ids, + labels=labels, + batch_size_t2i=batch_size_t2i, + batch_size_lm=batch_size_lm, + batch_size_mmu=batch_size_mmu, + batch_size_v2t=batch_size_v2t, + batch_size_s2t=batch_size_s2t, + batch_size_t2s=batch_size_t2s, + max_seq_length=config.dataset.preprocessing.max_seq_length, + p_mask_lm=p_mask_lm, + p_mask_mmu=p_mask_mmu, + p_mask_vid=p_mask_vid, + p_mask_s2t=p_mask_s2t, + p_mask_t2s=p_mask_t2s, + answer_lengths_mmu=answer_lengths_mmu, + answer_lengths_vid=answer_lengths_vid, + answer_lengths_s2t=answer_lengths_s2t, + answer_lengths_t2s=answer_lengths_t2s, + t2i_masks=t2i_masks, + t2s_vocab_start=speech_vocab_start, + t2s_codebook_size=audio_codebook_size, + t2s_special_token_ids=t2s_special_token_ids, + ) + + # Gather the losses across all processes for logging (use reduce to avoid shape mismatches) + # avg_loss_t2i = accelerator.reduce(loss_t2i, reduction='mean') + # avg_loss_lm = accelerator.reduce(loss_lm, reduction='mean') + # avg_loss_mmu = accelerator.reduce(loss_mmu, reduction='mean') + + avg_loss_vid = accelerator.reduce(loss_vid, reduction='mean') + avg_loss_s2t = accelerator.reduce(loss_s2t, reduction='mean') + avg_loss_t2s = accelerator.reduce(loss_t2s, reduction='mean') + + # loss = (config.training.t2i_coeff * loss_t2i + + # config.training.lm_coeff * loss_lm + + # config.training.mmu_coeff * loss_mmu + + # config.training.vid_coeff * loss_vid + + # config.training.s2t_coeff * loss_s2t + + # config.training.t2s_coeff * loss_t2s) + + loss = (config.training.v2t_coeff * loss_vid + + config.training.s2t_coeff * loss_s2t + + config.training.t2s_coeff * loss_t2s) + + # HMM~~~~~ + avg_masking_rate = accelerator.reduce(p_mask_t2s.mean(), reduction='mean') + + accelerator.backward(loss) + + if config.training.max_grad_norm is not None and accelerator.sync_gradients: + accelerator.clip_grad_norm_(model.parameters(), config.training.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + + # log gradient norm before zeroing it + if ( + accelerator.sync_gradients + and (global_step + 1) % config.experiment.log_grad_norm_every == 0 + and accelerator.is_main_process + ): + log_grad_norm(model, accelerator, global_step + 1) + + optimizer.zero_grad(set_to_none=True) + + if accelerator.sync_gradients: + batch_time_m.update(time.time() - end) + end = time.time() + + # Log metrics + if (global_step + 1) % config.experiment.log_every == 0: + samples_per_second_per_gpu = ( + config.training.gradient_accumulation_steps * total_batch_size_per_gpu / batch_time_m.val + ) + logs = { + # "step_loss_t2i": avg_loss_t2i.item(), + # "step_loss_mmu": avg_loss_mmu.item(), + # "step_loss_lm": avg_loss_lm.item(), + "step_loss_vid": avg_loss_vid.item(), + "step_loss_s2t": avg_loss_s2t.item(), + "step_loss_t2s": avg_loss_t2s.item(), + "lr": lr_scheduler.get_last_lr()[0], + # "avg_masking_rate": avg_masking_rate.item(), + "samples/sec/gpu": samples_per_second_per_gpu, + "data_time": data_time_m.val, + "batch_time": batch_time_m.val, + } + accelerator.log(logs, step=global_step + 1) + + logger.info( + f"Step: {global_step + 1} " + # f"Loss_t2i: {avg_loss_t2i.item():0.4f} " + # f"Loss_mmu: {avg_loss_mmu.item():0.4f} " + # f"Loss_lm: {avg_loss_lm.item():0.4f} " + f"Loss_vid: {avg_loss_vid.item():0.4f} " + f"Loss_s2t: {avg_loss_s2t.item():0.4f} " + f"Loss_t2s: {avg_loss_t2s.item():0.4f} " + f"Data (t): {data_time_m.val:0.4f}, {samples_per_second_per_gpu:0.2f}/s/gpu " + f"Batch (t): {batch_time_m.val:0.4f} " + f"LR: {lr_scheduler.get_last_lr()[0]:0.6f}" + ) + + # resetting batch / data time meters per log window + batch_time_m.reset() + data_time_m.reset() + + # Save model checkpoint + if (global_step + 1) % config.experiment.save_every == 0: + save_checkpoint(model, config, accelerator, global_step + 1, uni_prompting) + + # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + # ++++++++++++++++++++++ RUN EVALUATION +++++++++++++++++++++++++ + # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + if global_step == 0 or (global_step + 1) % config.experiment.get("eval_every", 5000) == 0: + run_evaluation( + model=accelerator.unwrap_model(model), + vq_model_image=vq_model_image, + vq_model_audio=vq_model_audio, + uni_prompting=uni_prompting, + config=config, + accelerator=accelerator, + global_step=global_step + 1 + ) + # Evaluation function sets model back to train mode internally + # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + + global_step += 1 + + if global_step >= config.training.max_train_steps: + break + + if global_step >= config.training.max_train_steps: + break + + accelerator.wait_for_everyone() + + save_checkpoint(model, config, accelerator, global_step, uni_prompting) + + if accelerator.is_main_process: + model = accelerator.unwrap_model(model) + model.save_pretrained(config.experiment.output_dir, safe_serialization=True) + + accelerator.end_training() + +@torch.no_grad() +def visualize_predictions(*args, **kwargs): + # This function is not called in the main loop but kept for compatibility + pass + +@torch.no_grad() +def generate_images(*args, **kwargs): + # This function is not called in the main loop but kept for compatibility + pass + +@torch.no_grad() +def understanding_images(*args, **kwargs): + # This function is not called in the main loop but kept for compatibility + pass + +def save_checkpoint(model, config, accelerator, global_step, uni_prompting): + output_dir = config.experiment.output_dir + checkpoints_total_limit = config.experiment.get("checkpoints_total_limit", None) + + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if accelerator.is_main_process and checkpoints_total_limit is not None: + checkpoints = os.listdir(output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= checkpoints_total_limit: + num_to_remove = len(checkpoints) - checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = Path(output_dir) / f"checkpoint-{global_step}" + + # retrieve the model on all processes for deepspeed stage 3 to work then save on one process (we are not using stage 3 yet) + # XXX: could also make this conditional on deepspeed + state_dict = accelerator.get_state_dict(model) + if accelerator.is_main_process: + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.save_pretrained( + save_path / "unwrapped_model", + save_function=accelerator.save, + state_dict=state_dict, + safe_serialization=True + ) + json.dump({"global_step": global_step}, (save_path / "metadata.json").open("w+")) + logger.info(f"Saved state to {save_path}") + + # save tokenizer + uni_prompting.text_tokenizer.save_pretrained(save_path/ "unwrapped_model") + + +def log_grad_norm(model, accelerator, global_step): + for name, param in model.named_parameters(): + if param.grad is not None: + grads = param.grad.detach().data + grad_norm = (grads.norm(p=2) / grads.numel()).item() + accelerator.log({"grad_norm/" + name: grad_norm}, step=global_step) + + +if __name__ == "__main__": + main() diff --git a/MMaDA/training/train_omada_stage1.py b/MMaDA/training/train_omada_stage1.py new file mode 100644 index 0000000000000000000000000000000000000000..d89d2c91b4292d256d59fb3c414aa0f209650c31 --- /dev/null +++ b/MMaDA/training/train_omada_stage1.py @@ -0,0 +1,1458 @@ +# Copyright 2025 AIDAS Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +os.environ["TOKENIZERS_PARALLELISM"] = "true" +import json +import pandas +import logging +import math +import shutil +import time +import cv2 +import glob +import random +from tqdm import tqdm +from pathlib import Path +from typing import Union +import csv +import numpy as np +from PIL import Image +from omegaconf import OmegaConf +import wandb +import torch +from torch.optim import AdamW +from lightning.pytorch.utilities import CombinedLoader + +from transformers import AutoTokenizer, AutoConfig +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import DistributedType, set_seed +# +++++ I2I-specific Imports +++++ +from datasets import load_dataset +from torch.utils.data import Dataset, DataLoader +from tqdm.auto import tqdm +# ++++++++++++++++++++++++++++++ + +# +++++ Omni-modal-specific Imports +++++ +from models.modeling_emova_speech_tokenizer import EMOVASpeechTokenizer +from datasets import load_dataset +from torch.utils.data import Dataset, DataLoader, DistributedSampler +from tqdm.auto import tqdm +from training.data import SpeechTextDataset, MixedSpeechTextDataset, load_video_mp4, VideoCaptionDataset, S2T_INSTRUCTION, T2S_INSTRUCTION +# import librosa + +from training.data import Text2ImageDataset +from training.utils import get_config, flatten_omega_conf, image_transform +from training.imagenet_dataset import ImageNetDataset +from parquet import RefinedWebDataset, ChatDataset + +from models import MAGVITv2, get_mask_schedule, OMadaModelLM, OMadaConfig +from training.prompting_utils import UniversalPrompting +from models.lr_schedulers import get_scheduler +from models.logging import set_verbosity_info, set_verbosity_error + +from torch.utils.data import DataLoader, Dataset +from torch.utils.data.distributed import DistributedSampler + + +SYSTEM_PROMPT_LEN = 28 + +from training.utils import get_config, flatten_omega_conf, mask_or_random_replace_tokens, AverageMeter + +try: + import apex + + is_apex_available = True +except ImportError: + is_apex_available = False + +logger = get_logger(__name__, log_level="INFO") + +def pad_tensor(tensor, length, value): + pad_size = length - tensor.shape[1] + if pad_size <= 0: + return tensor + # Pad on the right side of the sequence (last dimension) + return torch.nn.functional.pad(tensor, (0, pad_size), "constant", value) + +def pad_answer_lengths(ans: torch.Tensor, length: int) -> torch.Tensor: + b, l = ans.shape + if l >= length: + return ans + pad_block = ans[:, :1].expand(b, length - l) + return torch.cat([ans, pad_block], dim=1) + +def resize_vocab(model, config): + logger.info(f"Resizing token embeddings to {config.model.omada.new_vocab_size}") + model.resize_token_embeddings(config.model.omada.new_vocab_size) + +def get_vq_model_class(model_type): + if model_type == "magvitv2": + return MAGVITv2 + elif model_type == "emova": + return EMOVASpeechTokenizer.from_pretrained( + "Emova-ollm/emova_speech_tokenizer_hf" + ) + else: + raise ValueError(f"model_type {model_type} not supported.") + +def collate_fn_audio(batch): + # In this setup, the tokenizer handles batching of audio paths + return { + 'audio_path': [item['audio_path'] for item in batch], + 'text': [item['text'] for item in batch], + } + +def collate_fn_video_caption(batch): + frame_list = [] + input_ids_list = [] + for item in batch: + frame_tensor = torch.stack(item['video'], dim=0) # (T, C, H, W) + frame_list.append(frame_tensor) + input_ids_list.append(item['caption']) + + frames = torch.stack(frame_list, dim=0) # (B, T, C, H, W) + #input_ids = torch.stack(input_ids_list, dim=0) + + return { + "video": frames, # torch tensor (B, T, C, H, W) + "captions": input_ids_list # input_ids (B, seq_len) + } + +def main(): + ######################### + # SETUP Accelerator # + ######################### + config = get_config() + + # Enable TF32 on Ampere GPUs + if config.training.enable_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.benchmark = True + torch.backends.cudnn.deterministic = False + + config.experiment.logging_dir = str(Path(config.experiment.output_dir) / "logs") + accelerator = Accelerator( + gradient_accumulation_steps=config.training.gradient_accumulation_steps, + mixed_precision=config.training.mixed_precision, + log_with="wandb", + project_dir=config.experiment.logging_dir, + split_batches=True, + ) + + total_batch_size_per_gpu = (config.training.batch_size_t2i + + config.training.batch_size_lm + + config.training.batch_size_mmu + + config.training.batch_size_v2t + + config.training.batch_size_s2t + + config.training.batch_size_t2s) + total_batch_size = ( + (config.training.batch_size_t2i + + config.training.batch_size_lm + + config.training.batch_size_mmu + + config.training.batch_size_v2t + + config.training.batch_size_s2t + + config.training.batch_size_t2s) * accelerator.num_processes * config.training.gradient_accumulation_steps + ) + + if accelerator.distributed_type == DistributedType.DEEPSPEED: + accelerator.state.deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = ( + total_batch_size_per_gpu + ) + + ##################################### + # SETUP LOGGING, SEED and CONFIG # + ##################################### + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + set_verbosity_info() + else: + set_verbosity_error() + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + resume_wandb_run = config.wandb.resume + run_id = config.wandb.get("run_id", None) + if run_id is None: + resume_wandb_run = False + run_id = wandb.util.generate_id() + config.wandb.run_id = run_id + + wandb_init_kwargs = dict( + name=config.experiment.name, + id=run_id, + resume=resume_wandb_run, + entity=config.wandb.get("entity", None), + config_exclude_keys=[], + dir = config.experiment.logging_dir, + ) + wandb_config = {k: v for k, v in flatten_omega_conf(config, resolve=True)} + wandb_config.pop("experiment.resume_from_checkpoint") + + accelerator.init_trackers( + config.experiment.project, + config=wandb_config, + init_kwargs={"wandb": wandb_init_kwargs}, + ) + + if accelerator.is_main_process: + os.makedirs(config.experiment.output_dir, exist_ok=True) + config_path = Path(config.experiment.output_dir) / "config.yaml" + logging.info(f"Saving config to {config_path}") + OmegaConf.save(config, config_path) + + # If passed along, set the training seed now. + if config.training.seed is not None: + set_seed(config.training.seed) + + ######################### + # MODELS and OPTIMIZER # + ######################### + logger.info("Loading models and optimizer") + + tokenizer = AutoTokenizer.from_pretrained(config.model.omada.tokenizer_path, padding_side="left") + + uni_prompting = UniversalPrompting(tokenizer, max_text_len=config.dataset.preprocessing.max_seq_length, max_audio_len=config.dataset.preprocessing.max_aud_length, + special_tokens=( + "<|soi|>", "<|eoi|>", "<|sov|>", "<|eov|>", "<|t2i|>", + "<|mmu|>", "<|t2v|>", "<|v2v|>", "<|lvg|>", + # Omada Special Tokens + "<|v2t|>", "<|s2t|>", "<|t2s|>", "<|soa|>", "<|eoa|>", + ), + ignore_id=-100, cond_dropout_prob=config.training.cond_dropout_prob, use_reserved_token=True) + + print('special tokens : \n', uni_prompting.sptids_dict) + + speech_vocab_start = len(uni_prompting.text_tokenizer) + int(config.model.omada.codebook_size) + audio_codebook_size = max(int(config.model.omada.new_vocab_size) - speech_vocab_start, 0) + t2s_special_token_ids = { + "eoa": int(uni_prompting.sptids_dict['<|eoa|>'][0].item()), + "eos": int(uni_prompting.text_tokenizer.eos_token_id), + } + + # VQ model for processing image into discrete tokens + vq_model_image = get_vq_model_class(config.model.vq_model_image.type) + if config.model.vq_model_image.get("pretrained_model_path", None): + vq_model_image = vq_model_image().to(accelerator.device) + state_dict = torch.load(config.model.vq_model_image.pretrained_model_path)['model'] + vq_model_image.load_state_dict(state_dict) + else: + vq_model_image = vq_model_image.from_pretrained(config.model.vq_model_image.vq_model_name).to(accelerator.device) + + vq_model_audio = get_vq_model_class(config.model.vq_model_audio.type) + vq_model_audio = vq_model_audio.from_pretrained(config.model.vq_model_audio.vq_model_name).to(accelerator.device) + + vq_model_image.eval() + vq_model_image.requires_grad_(False) + + vq_model_audio.eval() + vq_model_audio.requires_grad_(False) + + model = OMadaModelLM.from_pretrained(config.model.omada.pretrained_model_path, torch_dtype=torch.bfloat16).to(accelerator.device) + + # Resize Vocab size for Audio Modality + unwrapped_model = accelerator.unwrap_model(model) + original_vocab_size = unwrapped_model.get_input_embeddings().weight.shape[0] + logger.info("="*50) + logger.info(f"Calling resize_vocab...") + logger.info(f"Vocab size BEFORE resizing: {original_vocab_size}") + + resize_vocab(unwrapped_model, config) + + resized_vocab_size = unwrapped_model.get_input_embeddings().weight.shape[0] + logger.info(f"Vocab size AFTER resizing: {resized_vocab_size}") + logger.info(f"Config 'new_vocab_size': {config.model.omada.new_vocab_size}") + + if resized_vocab_size == config.model.omada.new_vocab_size: + logger.info("āœ… Vocab resize successful!") + else: + logger.info("āŒ Vocab resize FAILED or did not match config!") + logger.info("="*50) + mask_id = model.config.mask_token_id + + ################################## + # Optimizer and LR scheduler # + ################################# + optimizer_config = config.optimizer.params + + # no decay on bias and layernorm and embedding + no_decay = ["bias", "layer_norm.weight", "mlm_ln.weight", "embeddings.weight"] + optimizer_grouped_parameters = [ + { + "params": [p for n, p in model.named_parameters() if + p.requires_grad and not any(nd in n for nd in no_decay)], + "weight_decay": optimizer_config.weight_decay, + }, + { + "params": [p for n, p in model.named_parameters() if + p.requires_grad and any(nd in n for nd in no_decay)], + "weight_decay": 0.0, + }, + ] + + optimizer_type = config.optimizer.name + if optimizer_type == "adamw": + optimizer = AdamW( + optimizer_grouped_parameters, + lr=optimizer_config.learning_rate, + betas=(optimizer_config.beta1, optimizer_config.beta2), + weight_decay=optimizer_config.weight_decay, + eps=optimizer_config.epsilon, + ) + else: + raise ValueError(f"Optimizer {optimizer_type} not supported") + + # Create mask scheduler + if config.get("mask_schedule", None) is not None: + schedule = config.mask_schedule.schedule + args = config.mask_schedule.get("params", {}) + mask_schedule = get_mask_schedule(schedule, **args) + else: + mask_schedule = get_mask_schedule(config.training.get("mask_schedule", "cosine")) + + # lr_warmup_steps_for_scheduler = config.lr_scheduler.params.warmup_steps + # max_train_steps_for_scheduler = config.training.max_train_steps + + # lr_warmup_steps_for_scheduler = config.lr_scheduler.params.warmup_steps * accelerator.num_processes + # max_train_steps_for_scheduler = config.training.max_train_steps * accelerator.num_processes + + # lr_scheduler = get_scheduler( + # config.lr_scheduler.scheduler, + # optimizer=optimizer, + # num_warmup_steps=lr_warmup_steps_for_scheduler, + # num_training_steps=max_train_steps_for_scheduler, + # min_lr_scale=config.lr_scheduler.params.min_lr_scale + # ) + + ################################## + # DATALOADER # + ################################# + logger.info("Creating dataloaders and lr_scheduler") + + # total_batch_size_without_accum = config.training.batch_size_t2s * accelerator.num_processes + total_batch_size = ( + (config.training.batch_size_t2s + config.training.batch_size_s2t +config.training.batch_size_v2t) * accelerator.num_processes * config.training.gradient_accumulation_steps + ) + + # DataLoaders creation: + # We use webdataset for data loading. The dataloaders are created with sampling with replacement. + # We don't do dataset resuming here, instead we resample the shards and buffer each time. The sampling is stochastic. + # This means that the dataloading is not deterministic, but it's fast and efficient. + preproc_config = config.dataset.preprocessing + dataset_config = config.dataset.params + + # Data for generation + if config.dataset.gen_type == "t2i": + dataset = Text2ImageDataset( + train_shards_path_or_url=dataset_config.train_t2i_shards_path_or_url, + tokenizer=None, # we want to get raw texts + max_seq_length=preproc_config.max_seq_length, + num_train_examples=config.experiment.max_train_examples_t2i, + per_gpu_batch_size=config.training.batch_size_t2i, + global_batch_size=total_batch_size_t2i_without_accum, + num_workers=dataset_config.num_workers, + resolution=preproc_config.resolution, + shuffle_buffer_size=dataset_config.shuffle_buffer_size, + pin_memory=dataset_config.pin_memory, + persistent_workers=dataset_config.persistent_workers, + external_caption_path=dataset_config.external_caption_path, + external_journeydb_caption_path=dataset_config.external_journeydb_caption_path, + external_laion12m_caption_path=dataset_config.external_laion12m_caption_path, + external_cc12m_caption_path=dataset_config.external_cc12m_caption_path, + ) + train_dataloader_t2i = dataset.train_dataloader + num_update_steps_per_epoch = math.ceil( + train_dataloader_t2i.num_batches / config.training.gradient_accumulation_steps) + num_train_epochs = math.ceil(config.training.max_train_steps / num_update_steps_per_epoch) + + elif config.dataset.gen_type == "t2i_parquet": + # this part relies on the internal packages, which will not be released + num_update_steps_per_epoch = math.ceil(config.experiment.max_train_examples_t2i / total_batch_size_t2i) + num_train_epochs = math.ceil(config.training.max_train_steps / num_update_steps_per_epoch) + + train_dataloader_t2i = create_imagetext_dataloader( + train_shards_path_or_url=dataset_config.train_t2i_shards_path_or_url, + batch_size=config.training.batch_size_t2i, + image_size=preproc_config.resolution, + num_workers=dataset_config.num_workers, + num_readers=32, + predefined_steps=num_update_steps_per_epoch, + drop_last=True, + shuffle=True, + shuffle_buffer_size=dataset_config.shuffle_buffer_size + ) + + elif config.dataset.gen_type == "imagenet1k": + dataset_imagenet = ImageNetDataset( + dataset_config.train_t2i_shards_path_or_url, + image_size=preproc_config.resolution, + ) + + print('process index : ', + accelerator.process_index, ', ', accelerator.num_processes, + "Length: ", len(dataset_imagenet)) + + if accelerator.num_processes > 1: + sampler = DistributedSampler(dataset_imagenet, + num_replicas=accelerator.num_processes, + rank=accelerator.process_index, + shuffle=True, + ) + shuffle = False + else: + sampler = None + shuffle = True + + train_dataloader_t2i = DataLoader(dataset_imagenet, batch_size=config.training.batch_size_t2i, + sampler=sampler, collate_fn=dataset_imagenet.collate_fn, + shuffle=shuffle, num_workers=dataset_config.num_workers) + num_update_steps_per_epoch = math.ceil(len(dataset_imagenet) / total_batch_size_t2i) + num_train_epochs = math.ceil(config.training.max_train_steps / num_update_steps_per_epoch) + + elif config.dataset.gen_type == "pass": + pass + + else: + raise ValueError(f"Unsupported dataset type {config.dataset.type}") + + + total_batch_size_mmu_without_accum = config.training.batch_size_mmu * accelerator.num_processes + # Data for image captioning + if config.dataset.und_type == "captioning": + dataset_mmu = xr( + train_shards_path_or_url=dataset_config.train_mmu_shards_path_or_url, + tokenizer=None, # we want to get raw texts + max_seq_length=preproc_config.max_seq_length, + num_train_examples=config.experiment.max_train_examples_mmu, + per_gpu_batch_size=config.training.batch_size_mmu, + global_batch_size=total_batch_size_mmu_without_accum, + num_workers=dataset_config.num_workers, + resolution=preproc_config.resolution, + shuffle_buffer_size=dataset_config.shuffle_buffer_size, + pin_memory=dataset_config.pin_memory, + persistent_workers=dataset_config.persistent_workers, + external_caption_path=dataset_config.external_caption_path, + external_journeydb_caption_path=dataset_config.external_journeydb_caption_path, + external_laion12m_caption_path=dataset_config.external_laion12m_caption_path, + external_cc12m_caption_path=dataset_config.external_cc12m_caption_path, + is_captioning=True, + add_caption_prompt=dataset_config.add_caption_prompt, + ) + train_dataloader_mmu = dataset_mmu.train_dataloader + + elif config.dataset.und_type == "captioning_parquet": + train_dataloader_mmu = create_imagetext_dataloader( + train_shards_path_or_url=dataset_config.train_mmu_shards_path_or_url, + batch_size=config.training.batch_size_mmu, + image_size=preproc_config.resolution, + num_workers=dataset_config.num_workers, + num_readers=32, + predefined_steps=num_update_steps_per_epoch, + drop_last=True, + shuffle=True, + shuffle_buffer_size=dataset_config.shuffle_buffer_size, + is_captioning=True + ) + elif config.dataset.gen_type == "pass": + pass + + else: + raise NotImplementedError(f"Unsupported dataset type {config.dataset.und_type}") + + + # Video Dataset + video_captioning_dataset = VideoCaptionDataset( + transform=image_transform, + tokenizer=uni_prompting.text_tokenizer, + max_seq_length=preproc_config.max_seq_length, + resolution=preproc_config.resolution, + sample_method="uniform", + num_frames=8, + ) + + sampler_v2t = DistributedSampler( + video_captioning_dataset, + shuffle=False, + drop_last=True + ) + + train_dataloader_v2t = DataLoader( + video_captioning_dataset, + batch_size=config.training.batch_size_v2t, + num_workers=dataset_config.num_workers, + collate_fn=collate_fn_video_caption, + sampler = sampler_v2t + ) + + # Speech Dataset + dataset_sm = MixedSpeechTextDataset(config.dataset.params.audio_data) + # eval_dataset = dataset_sm + + logger.info(f"Dataset Prepared.") + + sampler_sm = DistributedSampler(dataset_sm, num_replicas=accelerator.num_processes, rank=accelerator.process_index, shuffle=True) if accelerator.num_processes > 1 else None + # sampler_vid = DistributedSampler(dataset_sm, num_replicas=accelerator.num_processes, rank=accelerator.process_index, shuffle=True) if accelerator.num_processes > 1 else None + + train_dataloader_s2t = DataLoader(dataset_sm, batch_size=config.training.batch_size_s2t, shuffle=False, sampler=sampler_sm, collate_fn=collate_fn_audio, num_workers=config.dataset.params.num_workers) + train_dataloader_t2s = DataLoader(dataset_sm, batch_size=config.training.batch_size_t2s, shuffle=False, sampler=sampler_sm, collate_fn=collate_fn_audio, num_workers=config.dataset.params.num_workers) + + # eval_dataloader = DataLoader(eval_dataset, batch_size=config.training.batch_size_s2t, collate_fn=collate_fn, num_workers=config.dataset.params.num_workers) + + # LLM pure text dataset: RefinedWeb -> pass + # dataset_lm = RefinedWebDataset(data_path=dataset_config.train_lm_shards_path_or_url, + # rank=accelerator.process_index, + # world_size=accelerator.num_processes, + # num_workers=dataset_config.num_workers) + + # train_dataloader_lm = torch.utils.data.DataLoader(dataset_lm, batch_size=config.training.batch_size_lm, + # sampler=None, collate_fn=dataset_lm.collate_fn, + # num_workers=dataset_config.num_workers) + + # Combine these dataloaders into a single iterable model + iterables = { + # "t2i_flow": train_dataloader_t2i, + # "lm_flow": train_dataloader_lm, + # "mmu_flow": train_dataloader_mmu, + "v2t_flow": train_dataloader_v2t, + "t2s_flow": train_dataloader_t2s, + "s2t_flow": train_dataloader_s2t, + } + + combined_dataloader = CombinedLoader(iterables, mode=config.dataset.combined_loader_mode) + + # num_update_steps_per_epoch = math.ceil(len(dataset_sm)*2+len(video_captioning_dataset) / total_batch_size) + # num_train_epochs = math.ceil(config.training.max_train_steps / num_update_steps_per_epoch) + + # dataset_size = len(dataset_sm)*2 + len(video_captioning_dataset) + # global_batch_size = (total_batch_size_per_gpu * accelerator.num_processes * accelerator.gradient_accumulation_steps) + # num_update_steps_per_epoch = math.ceil(dataset_size / global_batch_size) + + # desired_epochs = config.training.max_train_epochs + + # config.training.max_train_steps = desired_epochs * num_update_steps_per_epoch + + # num_train_epochs = math.ceil(config.training.max_train_steps / num_update_steps_per_epoch) + + # s2t + total_batch_size_s2t = config.training.batch_size_s2t * accelerator.num_processes * config.training.gradient_accumulation_steps + num_update_steps_per_epoch_s2t = math.ceil(len(dataset_sm) / total_batch_size_s2t) + + # t2s + total_batch_size_t2s = config.training.batch_size_t2s * accelerator.num_processes * config.training.gradient_accumulation_steps + num_update_steps_per_epoch_t2s = math.ceil(len(dataset_sm) / total_batch_size_t2s) + + # v2t + total_batch_size_v2t = (config.training.batch_size_v2t * accelerator.num_processes * config.training.gradient_accumulation_steps) + num_update_steps_per_epoch_v2t = math.ceil(len(video_captioning_dataset) / total_batch_size_v2t) + + + # Calculate num_train_epochs + num_update_steps_per_epoch = max(num_update_steps_per_epoch_s2t, num_update_steps_per_epoch_t2s, num_update_steps_per_epoch_v2t) + num_train_epochs = math.ceil(config.training.max_train_steps / num_update_steps_per_epoch) + + logger.info(f"len of speech: {len(dataset_sm)}") + logger.info(f"len of video: {len(video_captioning_dataset)}") + logger.info(f"Train stpes: {config.training.max_train_steps}") + logger.info(f"Num train epochs: {num_train_epochs}") + + # sys.exit() + ################################## + # MODEL RESUME # + ################################# + global_step = 0 + first_epoch = 0 + start_step = 0 + + if config.experiment.resume_from_checkpoint: + dirs = os.listdir(config.experiment.output_dir) + logger.info(f"dirs: {dirs}") + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + logger.info(f"path: {path}") + if path is not None: + path = os.path.join(config.experiment.output_dir, path) + logger.info(f"Resuming from checkpoint: {path}") + global_step = start_step = int(os.path.basename(path).split("-")[1]) + first_epoch = global_step // num_update_steps_per_epoch + if os.path.exists(f'{path}/unwrapped_model/pytorch_model.bin'): + state_dict = torch.load(f'{path}/unwrapped_model/pytorch_model.bin', map_location="cpu") + model.load_state_dict(state_dict, strict=True) + del state_dict + elif os.path.exists(f'{path}/unwrapped_model/pytorch_model.bin.index.json'): + from safetensors.torch import load_file + from transformers.modeling_utils import load_sharded_checkpoint + load_sharded_checkpoint(model, f'{path}/unwrapped_model/') + # if safetensors sharded checkpoint exists + elif os.path.exists(f'{path}/unwrapped_model/model.safetensors.index.json'): + from transformers.modeling_utils import load_sharded_checkpoint + load_sharded_checkpoint( + model, + f'{path}/unwrapped_model/', + # weight_map=None, + # load_state_dict_fn="safetensors" + ) + else: + raise FileNotFoundError(f"Checkpoint {path}/unwrapped_model/pytorch_model.bin not found") + else: + logger.info("Not resuming from checkpoint") + + ################################## + # Prepare accelerator # + ################################# + logger.info("Preparing model, optimizer and dataloaders") + + lr_scheduler = get_scheduler( + config.lr_scheduler.scheduler, + optimizer=optimizer, + num_training_steps=config.training.max_train_steps, + num_warmup_steps=config.lr_scheduler.params.warmup_steps, + min_lr_scale=config.lr_scheduler.params.min_lr_scale + ) + + # model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler) + model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler) + + lr_scheduler = get_scheduler( + config.lr_scheduler.scheduler, + optimizer=optimizer, + num_training_steps=config.training.max_train_steps, + num_warmup_steps=config.lr_scheduler.params.warmup_steps, + min_lr_scale=config.lr_scheduler.params.min_lr_scale + ) + + vq_model_image.to(device=accelerator.device) + vq_model_audio.to(device=accelerator.device) + + mask_dtype = model.get_input_embeddings().weight.dtype + + ################################## + # Training # + ################################# + logger.info("***** Running training *****") + logger.info(f" Num training steps = {config.training.max_train_steps}") + logger.info(f" Instantaneous batch size per device = {total_batch_size_per_gpu}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {config.training.gradient_accumulation_steps}") + + @torch.no_grad() + def prepare_inputs_and_labels( + pixel_values_or_image_ids: Union[torch.FloatTensor, torch.LongTensor], + texts: Union[str, str], + min_masking_rate: float = 0.0, + is_train: bool = True, + seed: int = None + ): + + image_tokens = vq_model_image.get_code(pixel_values_or_image_ids) + image_tokens = image_tokens + len(uni_prompting.text_tokenizer) + # create MLM mask and labels + input_ids, labels, loss_weight, mask_prob = mask_or_random_replace_tokens( + image_tokens, + mask_id, + config, + mask_schedule=mask_schedule, + is_train=is_train, + ) + input_ids, masks, labels = uni_prompting((texts, input_ids, labels), 't2i') + return input_ids, labels, mask_prob, image_tokens, masks + + @torch.no_grad() + def prepare_inputs_and_labels_for_text( + texts: Union[str, str], max_seq_len, eps=1e-3 + ): + # create MLM mask and labels + + input_ids_lm, prompt_mask, labels_lm = uni_prompting((texts_lm, max_seq_len), 'lm') + b, l = input_ids_lm.shape + t = torch.rand(b, device=input_ids_lm.device) + p_mask = (1 - eps) * t + eps + p_mask = p_mask[:, None].repeat(1, l) + + masked_indices = torch.rand((b, l), device=input_ids_lm.device) < p_mask + # 126336 is used for [MASK] token + noisy_batch = torch.where(masked_indices, mask_id, input_ids_lm) + masked_indices = noisy_batch == mask_id + + return noisy_batch, labels_lm, p_mask + + # Video also uses this. + @torch.no_grad() + def prepare_inputs_and_labels_for_mmu( + input_ids_mmu, prompt_masks, labels_mmu, eps=1e-3 + ): + b, l = input_ids_mmu.shape + t = torch.rand(b, device=input_ids_mmu.device) + p_mask = (1 - eps) * t + eps + p_mask = p_mask[:, None].repeat(1, l) + + masked_indices = torch.rand((b, l), device=input_ids_mmu.device) < p_mask + # 126336 is used for [MASK] token + noisy_batch = torch.where(masked_indices, mask_id, input_ids_mmu) + masked_indices = noisy_batch == mask_id + noisy_batch[prompt_masks.bool()] = input_ids_mmu[prompt_masks.bool()] + masked_indices = noisy_batch == mask_id + + prompt_masks = prompt_masks.to(torch.int64) + answer_lengths = torch.sum((1 - prompt_masks), dim=-1, keepdim=True) + answer_lengths = answer_lengths.repeat(1, noisy_batch.shape[1]) + + return noisy_batch, labels_mmu, p_mask, answer_lengths + + @torch.no_grad() + def prepare_inputs_and_labels_for_t2s( + input_ids_t2s, prompt_masks, labels_t2s, mask_id=126336, eps=1e-3 + ): + b, l = input_ids_t2s.shape + t = torch.rand(b, device=input_ids_t2s.device) + p_mask = (1 - eps) * t + eps + p_mask = p_mask[:, None].repeat(1, l) + + masked_indices = torch.rand((b, l), device=input_ids_t2s.device) < p_mask + noisy_batch = torch.where(masked_indices, mask_id, input_ids_t2s) + masked_indices = noisy_batch == mask_id + + noisy_batch[prompt_masks.bool()] = input_ids_t2s[prompt_masks.bool()] + masked_indices = noisy_batch == mask_id + + prompt_masks = prompt_masks.to(torch.int64) + answer_lengths = torch.sum((1 - prompt_masks), dim=-1, keepdim=True) + answer_lengths = answer_lengths.repeat(1, noisy_batch.shape[1]) + + return noisy_batch, labels_t2s, p_mask, answer_lengths + + + @torch.no_grad() + def prepare_inputs_and_labels_for_s2t( + input_ids_mmu, prompt_masks, labels_mmu, eps=1e-3 + ): + b, l = input_ids_mmu.shape + t = torch.rand(b, device=input_ids_mmu.device) + p_mask = (1 - eps) * t + eps + p_mask = p_mask[:, None].repeat(1, l) + + masked_indices = torch.rand((b, l), device=input_ids_mmu.device) < p_mask + # 126336 is used for [MASK] token + noisy_batch = torch.where(masked_indices, mask_id, input_ids_mmu) + masked_indices = noisy_batch == mask_id + noisy_batch[prompt_masks.bool()] = input_ids_mmu[prompt_masks.bool()] + masked_indices = noisy_batch == mask_id + + prompt_masks = prompt_masks.to(torch.int64) + answer_lengths = torch.sum((1 - prompt_masks), dim=-1, keepdim=True) + answer_lengths = answer_lengths.repeat(1, noisy_batch.shape[1]) + + return noisy_batch, labels_mmu, p_mask, answer_lengths + + batch_time_m = AverageMeter() + data_time_m = AverageMeter() + end = time.time() + + for epoch in tqdm(range(first_epoch, num_train_epochs), desc="Epochs", disable=not accelerator.is_main_process, position=0): + model.train() + for batch, batch_idx, dataloader_idx in combined_dataloader: + # micro_steps_total += 1 + # for loss calculation + # batch_size_t2i = batch["t2i_flow"]["images"].shape[0] + # batch_size_lm = len(batch["lm_flow"]["input_ids"]) + # batch_size_mmu = batch["mmu_flow"]["images"].shape[0] + batch_size_t2i = 0 + batch_size_lm = 0 + batch_size_mmu = 0 + + batch_size_v2t = len(batch["v2t_flow"]["video"]) + batch_size_t2s = len(batch["t2s_flow"]["audio_path"]) + batch_size_s2t = len(batch["s2t_flow"]["audio_path"]) + + # print(f"Rank {accelerator.process_index} loading data...") + # print(batch["s2t_flow"]["audio_path"]) + # print(batch["v2t_flow"]['captions']) + + # sys.exit() + + logger.info(f"batch_size_v2t: {batch_size_v2t}, batch_size_t2s: {batch_size_t2s}, batch_size_s2t: {batch_size_s2t}" ) + + # *-------*-------*-------*-------*-------*-------*-------*-------*-------*-------*-------* + # Build formatted sequences for class-conditional/text-to-image generation + # *-------*-------*-------*-------*-------*-------*-------*-------*-------*-------*-------* + # pixel_values, texts = batch["t2i_flow"]["images"], batch["t2i_flow"]["input_ids"] + # pixel_values = pixel_values.to(accelerator.device, non_blocking=True) + # Omada speech + audio_paths_s2t, texts_s2t = batch["s2t_flow"]["audio_path"], batch["s2t_flow"]["text"] # + "into text" + audio_paths_t2s, texts_t2s = batch["t2s_flow"]["audio_path"], batch["t2s_flow"]["text"] # + "into speech" + + offset = speech_vocab_start + + # print(f"len(uni_prompting.text_tokenizer): {len(uni_prompting.text_tokenizer)}") + # print(f"offset: {offset}") + # sys.exit() + # Omada video + video_tensor, texts_vid = batch["v2t_flow"]["video"], batch["v2t_flow"]["captions"] + + data_time_m.update(time.time() - end) + + # # Encode images to image tokens, mask them and create input and labels + # ( + # input_ids, + # labels, + # mask_prob, + # image_tokens_ori, + # t2i_masks + # ) = prepare_inputs_and_labels(pixel_values, texts, config.training.min_masking_rate) + + # # *-------*-------*-------*-------*-------*-------*-------*-------*-------*-------*-------* + # # Build formatted sequences for language modeling + # # *-------*-------*-------*-------*-------*-------*-------*-------*-------*-------*-------* + # max_seq_len = input_ids.shape[-1] + # texts_lm = batch["lm_flow"]["input_ids"] + # ( + # input_ids_lm, + # labels_lm, + # p_mask_lm + # ) = prepare_inputs_and_labels_for_text(texts_lm, max_seq_len) + # input_ids = torch.cat((input_ids, input_ids_lm.to(input_ids.device)), dim=0) + # labels = torch.cat((labels, labels_lm.to(input_ids.device)), dim=0) + + + + # # *-------*-------*-------*-------*-------*-------*-------*-------*-------*-------*-------* + # # Build formatted sequences for captioning/multimodal understanding + # # *-------*-------*-------*-------*-------*-------*-------*-------*-------*-------*-------* + # if "llava" in config.dataset.und_type: + # pixel_values_mmu, input_ids_mmu, labels_mmu = (batch["mmu_flow"]["images"], batch["mmu_flow"]["input_ids"],batch["mmu_flow"]["labels"]) + # pixel_values_mmu = pixel_values_mmu.to(accelerator.device, non_blocking=True) + # input_ids_mmu = input_ids_mmu.to(accelerator.device, non_blocking=True) + # image_tokens_mmu = vq_model_image.get_code(pixel_values_mmu) + # image_tokens_mmu = image_tokens_mmu + len(uni_prompting.text_tokenizer) + + # input_ids_mmu = torch.cat([ + # (torch.ones(input_ids_mmu.shape[0], 1) * uni_prompting.sptids_dict['<|mmu|>']).to( + # accelerator.device), + # (torch.ones(input_ids_mmu.shape[0], 1) * uni_prompting.sptids_dict['<|soi|>']).to( + # accelerator.device), + # image_tokens_mmu, + # (torch.ones(input_ids_mmu.shape[0], 1) * uni_prompting.sptids_dict['<|eoi|>']).to( + # accelerator.device), + # input_ids_mmu, + # ], dim=1).long() + + # labels_mmu = torch.cat([ + # (torch.ones(input_ids_mmu.shape[0], 1) * uni_prompting.ignore_id).to(accelerator.device), + # (torch.ones(input_ids_mmu.shape[0], 1) * uni_prompting.ignore_id).to(accelerator.device), + # torch.ones_like(image_tokens_mmu) * uni_prompting.ignore_id, + # (torch.ones(input_ids_mmu.shape[0], 1) * uni_prompting.ignore_id).to(accelerator.device), + # labels_mmu.to(accelerator.device) + # ], dim=1).long() + + # else: + # pixel_values_mmu, texts_mmu = batch["mmu_flow"]["images"], batch["mmu_flow"]["input_ids"] + # pixel_values_mmu = pixel_values_mmu.to(accelerator.device, non_blocking=True) + # image_tokens_mmu = vq_model_image.get_code(pixel_values_mmu) + # image_tokens_mmu = image_tokens_mmu + len(uni_prompting.text_tokenizer) + + # input_ids_mmu, prompt_masks, labels_mmu = uni_prompting((image_tokens_mmu, texts_mmu), 'mmu') + # ( + # input_ids_mmu, + # labels_mmu, + # p_mask_mmu, + # answer_lengths_mmu + # ) = prepare_inputs_and_labels_for_mmu(input_ids_mmu, prompt_masks, labels_mmu) + # input_ids_mmu = input_ids_mmu.to(accelerator.device, non_blocking=True) + + # *-------*-------*-------*-------*-------*-------*-------*-------*-------*-------*-------* + # Build formatted sequences for video understanding + # *-------*-------*-------*-------*-------*-------*-------*-------*-------*-------*-------* + video_tensor = video_tensor.to(accelerator.device, non_blocking=True) + video_token_list = [] + for video in video_tensor: # each video is (T, C, H, W) + video_token = vq_model_image.get_code(video) # (T, D) + # each video is tokenized into (T, D) + video_token = video_token + len(uni_prompting.text_tokenizer) # add offset for video tokens + video_token = video_token.view(-1) # flatten to (T*D) + video_token_list.append(video_token) + + video_tokens = torch.stack(video_token_list, dim=0) # (B, T*D) + input_ids_vid, prompt_masks_vid, labels_vid = uni_prompting((video_tokens, texts_vid), 'v2t') + + ( + input_ids_vid, + labels_vid, + p_mask_vid, + answer_lengths_vid + ) = prepare_inputs_and_labels_for_mmu(input_ids_vid, prompt_masks_vid, labels_vid) + + input_ids_vid = input_ids_vid.to(accelerator.device, non_blocking=True) + # *-------*-------*-------*-------*-------*-------*-------*-------*-------*-------*-------* + # Build formatted sequences for speech understanding + # *-------*-------*-------*-------*-------*-------*-------*-------*-------*-------*-------* + prompt_s2t = ['<|start_header_id|>user<|end_header_id|>\n' + prompt + '<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n' for prompt in S2T_INSTRUCTION] + + all_audio_tokens = [] + for path in audio_paths_s2t: + tokens = vq_model_audio.encode(path).to(accelerator.device, non_blocking=True) + tokens_with_offset = tokens + offset + all_audio_tokens.append(tokens_with_offset) + + prompt = random.choice(prompt_s2t) + texts_with_prompt = [f"{prompt}{text}" for text in texts_s2t] + + input_ids_s2t, prompt_masks_s2t, labels_s2t = uni_prompting((all_audio_tokens, texts_with_prompt), 's2t') + + input_ids_s2t, labels_s2t, p_mask_s2t, answer_lengths_s2t = prepare_inputs_and_labels_for_s2t(input_ids_s2t, prompt_masks_s2t, labels_s2t) + + # print("#################S2T########################") + + # print(f"texts_with_prompt_s2t: {all_audio_tokens}") + + # print(f"prompt_masks_s2t: {prompt_masks_s2t}") + + # print(f"texts_s2t: {texts_with_prompt}") + + # print(f"input_ids_s2t: {input_ids_s2t}") + + # print(f"labels_s2t: {labels_s2t}") + + # print("############################################") + + # *-------*-------*-------*-------*-------*-------*-------*-------*-------*-------*-------* + # Build formatted sequences for speech generation + # *-------*-------*-------*-------*-------*-------*-------*-------*-------*-------*-------* + # prompt_t2s = ['<|start_header_id|>user<|end_header_id|>\n' + prompt + '<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n' for prompt in T2S_INSTRUCTION] + + prompt_t2s = [prompt for prompt in T2S_INSTRUCTION] + + all_audio_tokens = [] + for path in audio_paths_t2s: + tokens = vq_model_audio.encode(path).to(accelerator.device, non_blocking=True) + tokens_with_offset = tokens + offset + all_audio_tokens.append(tokens_with_offset) + + prompt = random.choice(prompt_t2s) + texts_with_prompt = [f"{text}\n{prompt}" for text in texts_t2s] + + input_ids_t2s, prompt_masks_t2s, labels_t2s = uni_prompting((texts_with_prompt, all_audio_tokens), 't2s') + + input_ids_t2s, labels_t2s, p_mask_t2s, answer_lengths_t2s = prepare_inputs_and_labels_for_t2s(input_ids_t2s, prompt_masks_t2s, labels_t2s) + + # print("#################T2S########################") + + # print(f"texts_with_prompt_t2s: {texts_with_prompt[0]}") + + # print(f"prompt_masks_t2s: {prompt_masks_t2s[0]}") + + # print(f"all_audio_tokens: {all_audio_tokens[0]}") + + # print(f"input_ids_t2s: {input_ids_t2s[0]}") + + # print(f"labels_t2s: {labels_t2s[0]}") + + # print("############################################") + + # sys.exit() + # *-------*-------*-------*-------*-------*-------*-------*-------*-------*-------*-------* + # Concat everything + # *-------*-------*-------*-------*-------*-------*-------*-------*-------*-------*-------* + # input_ids = torch.cat((input_ids, + # input_ids_mmu.to(input_ids.device), + # input_ids_vid.to(input_ids.device), + # input_ids_s2t.to(input_ids.device), + # input_ids_t2s.to(input_ids.device)), dim=0) + # labels = torch.cat((labels, + # labels_mmu.to(input_ids.device), + # labels_vid.to(input_ids.device), + # labels_s2t.to(input_ids.device), + # labels_t2s.to(input_ids.device)), dim=0) + # --------------------------------------------------------------------------------- + # 1. Define padding values + pad_token_id = uni_prompting.text_tokenizer.eos_token_id + + # 2. Find the maximum sequence length in the current batch + max_len = max( + input_ids_vid.shape[1], + input_ids_s2t.shape[1], + input_ids_t2s.shape[1] + ) + + # 3. Pad all tensors to the max_len + # Pad all input_ids to the same length + input_ids_vid = pad_tensor(input_ids_vid, max_len, pad_token_id) + input_ids_s2t = pad_tensor(input_ids_s2t, max_len, pad_token_id) + input_ids_t2s = pad_tensor(input_ids_t2s, max_len, pad_token_id) + + # Pad all labels to the same length + labels_vid = pad_tensor(labels_vid, max_len, -100) + labels_s2t = pad_tensor(labels_s2t, max_len, -100) + labels_t2s = pad_tensor(labels_t2s, max_len, -100) + + p_mask_vid = pad_tensor(p_mask_vid, max_len, 1.0) + p_mask_s2t = pad_tensor(p_mask_s2t, max_len, 1.0) + p_mask_t2s = pad_tensor(p_mask_t2s, max_len, 1.0) + + answer_lengths_vid = pad_answer_lengths(answer_lengths_vid, max_len) + answer_lengths_s2t = pad_answer_lengths(answer_lengths_s2t, max_len) + answer_lengths_t2s = pad_answer_lengths(answer_lengths_t2s, max_len) + # --------------------------------------------------------------------------------- + + input_ids = torch.cat(( + input_ids_vid, + input_ids_s2t, + input_ids_t2s + ), dim=0) + labels = torch.cat(( + labels_vid, + labels_s2t, + labels_t2s + ), dim=0) + + # w/o texts and images + p_mask_lm = None + p_mask_mmu = None + answer_lengths_mmu = None + t2i_masks = None + + if global_step == 0 and epoch == 0: + logger.info("Input ids: {}".format(input_ids)) + logger.info("Input ids shape: {}".format(input_ids.shape)) + logger.info("Labels: {}".format(labels)) + + # with accelerator.accumulate(model): + logits, loss_t2i, loss_lm, loss_mmu, loss_vid, loss_s2t, loss_t2s = accelerator.unwrap_model(model).forward_process( + # logits, loss_t2i, loss_lm, loss_mmu, loss_vid, loss_s2t, loss_t2s = model.forward_process( + input_ids=input_ids, + labels=labels, + batch_size_t2i=batch_size_t2i, + batch_size_lm=batch_size_lm, + batch_size_mmu=batch_size_mmu, + batch_size_v2t=batch_size_v2t, + batch_size_s2t=batch_size_s2t, + batch_size_t2s=batch_size_t2s, + max_seq_length=config.dataset.preprocessing.max_seq_length, + p_mask_lm=p_mask_lm, + p_mask_mmu=p_mask_mmu, + p_mask_vid=p_mask_vid, + p_mask_s2t=p_mask_s2t, + p_mask_t2s=p_mask_t2s, + answer_lengths_mmu=answer_lengths_mmu, + answer_lengths_vid=answer_lengths_vid, + answer_lengths_s2t=answer_lengths_s2t, + answer_lengths_t2s=answer_lengths_t2s, + t2i_masks=t2i_masks, + t2s_vocab_start=speech_vocab_start, + t2s_codebook_size=audio_codebook_size, + t2s_special_token_ids=t2s_special_token_ids, + ) + + # Gather the losses across all processes for logging (if we use distributed training). + # avg_loss_t2i = accelerator.gather(loss_t2i.repeat(config.training.batch_size_t2i)).mean() + # avg_loss_lm = accelerator.gather(loss_lm.repeat(config.training.batch_size_lm)).mean() + # avg_loss_mmu = accelerator.gather(loss_mmu.repeat(config.training.batch_size_mmu)).mean() + + avg_loss_vid = accelerator.gather(loss_vid.repeat(config.training.batch_size_v2t)).mean() + avg_loss_s2t = accelerator.gather(loss_s2t.repeat(config.training.batch_size_s2t)).mean() + avg_loss_t2s = accelerator.gather(loss_t2s.repeat(config.training.batch_size_t2s)).mean() + + # loss = (config.training.t2i_coeff * loss_t2i + + # config.training.lm_coeff * loss_lm + + # config.training.mmu_coeff * loss_mmu + + # config.training.vid_coeff * loss_vid + + # config.training.s2t_coeff * loss_s2t + + # config.training.t2s_coeff * loss_t2s) + + loss = (config.training.v2t_coeff * loss_vid + + config.training.s2t_coeff * loss_s2t + + config.training.t2s_coeff * loss_t2s) + + # HMM~~~~~ + avg_masking_rate = accelerator.gather(p_mask_t2s.mean()).mean() + + accelerator.backward(loss) + + if config.training.max_grad_norm is not None and accelerator.sync_gradients: + accelerator.clip_grad_norm_(model.parameters(), config.training.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + + # log gradient norm before zeroing it + if ( + accelerator.sync_gradients + and (global_step + 1) % config.experiment.log_grad_norm_every == 0 + and accelerator.is_main_process + ): + log_grad_norm(model, accelerator, global_step + 1) + + optimizer.zero_grad(set_to_none=True) + + if accelerator.sync_gradients: + batch_time_m.update(time.time() - end) + end = time.time() + + # Log metrics + if (global_step + 1) % config.experiment.log_every == 0: + samples_per_second_per_gpu = ( + config.training.gradient_accumulation_steps * total_batch_size_per_gpu / batch_time_m.val + ) + logs = { + # "step_loss_t2i": avg_loss_t2i.item(), + # "step_loss_mmu": avg_loss_mmu.item(), + # "step_loss_lm": avg_loss_lm.item(), + "step_loss_vid": avg_loss_vid.item(), + "step_loss_s2t": avg_loss_s2t.item(), + "step_loss_t2s": avg_loss_t2s.item(), + "lr": lr_scheduler.get_last_lr()[0], + # "avg_masking_rate": avg_masking_rate.item(), + "samples/sec/gpu": samples_per_second_per_gpu, + "data_time": data_time_m.val, + "batch_time": batch_time_m.val, + } + accelerator.log(logs, step=global_step + 1) + + logger.info( + f"Step: {global_step + 1} " + # f"Loss_t2i: {avg_loss_t2i.item():0.4f} " + # f"Loss_mmu: {avg_loss_mmu.item():0.4f} " + # f"Loss_lm: {avg_loss_lm.item():0.4f} " + f"Loss_vid: {avg_loss_vid.item():0.4f} " + f"Loss_s2t: {avg_loss_s2t.item():0.4f} " + f"Loss_t2s: {avg_loss_t2s.item():0.4f} " + f"Data (t): {data_time_m.val:0.4f}, {samples_per_second_per_gpu:0.2f}/s/gpu " + f"Batch (t): {batch_time_m.val:0.4f} " + f"LR: {lr_scheduler.get_last_lr()[0]:0.6f}" + ) + + # resetting batch / data time meters per log window + batch_time_m.reset() + data_time_m.reset() + + # Save model checkpoint + if (global_step + 1) % config.experiment.save_every == 0: + save_checkpoint(model, config, accelerator, global_step + 1, uni_prompting) + + # if ((global_step + 1) % config.experiment.generate_every == 0 or global_step == 0) and accelerator.is_main_process: + # generate_images( + # model, + # vq_model, + # uni_prompting, + # accelerator, + # config, + # global_step + 1, + # mask_schedule=mask_schedule, + # force_no_cfg=False + # ) + + # generate_images( + # model, + # vq_model, + # uni_prompting, + # accelerator, + # config, + # global_step + 1, + # mask_schedule=mask_schedule, + # force_no_cfg=True + # ) + + # visualize_predictions( + # model, + # vq_model, + # uni_prompting, + # config, + # global_step + 1, + # input_ids, + # image_tokens_ori, + # batch["t2i_flow"]["images"], + # texts, + # logits, + # accelerator + # ) + + # understanding_images( + # model, + # vq_model, + # uni_prompting, + # accelerator, + # config, + # global_step + 1, + # ) + + global_step += 1 + + # Stop training if max steps is reached + if global_step >= config.training.max_train_steps: + break + # End for + + accelerator.wait_for_everyone() + + # Evaluate and save checkpoint at the end of training + save_checkpoint(model, config, accelerator, global_step, uni_prompting) + + # Save the final trained checkpoint + if accelerator.is_main_process: + model = accelerator.unwrap_model(model) + model.save_pretrained(config.experiment.output_dir, safe_serialization=True) + + accelerator.end_training() + + +@torch.no_grad() +def visualize_predictions( + model, + vq_model, + uni_prompting, + config, + global_step, + input_ids, + image_tokens_ori, + ori_images, + texts, + logits, + accelerator +): + logger.info("Visualizing predictions...") + model.eval() + + recons_images = vq_model.decode_code(image_tokens_ori - len(uni_prompting.text_tokenizer)) + recons_images = torch.clamp((recons_images + 1.0) / 2.0, min=0.0, max=1.0) + recons_images *= 255.0 + recons_images = recons_images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8) + + images = torch.clamp((ori_images + 1.0) / 2.0, min=0.0, max=1.0) + images *= 255.0 + images = images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8) + predictions = logits[:config.training.batch_size_t2i, -(config.model.omada.num_vq_tokens + 1):-1:, len(uni_prompting.text_tokenizer) + config.model.omada.num_new_special_tokens: len(uni_prompting.text_tokenizer) + config.model.omada.num_new_special_tokens + config.model.omada.codebook_size] + predictions = predictions.argmax(axis=-1) + # mask_token_id = config.model.omada.vocab_size - 1 - len(uni_prompting.text_tokenizer) + mask_token_id = accelerator.unwrap_model(model).config.mask_token_id - len(uni_prompting.text_tokenizer) + input_ids = input_ids[:config.training.batch_size_t2i, -(config.model.omada.num_vq_tokens + 1):-1:] - len(uni_prompting.text_tokenizer) + mask_ratio = list((torch.where(input_ids == mask_token_id, 1, 0).sum( + dim=-1) / config.model.omada.num_vq_tokens).cpu().numpy()) + predicted_images = torch.where(input_ids == mask_token_id, predictions, input_ids) + predicted_images = vq_model.decode_code(predicted_images) + predicted_images = torch.clamp((predicted_images + 1.0) / 2.0, min=0.0, max=1.0) + predicted_images *= 255.0 + predicted_images = predicted_images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8) + predicted_images = np.concatenate((images, recons_images, predicted_images), 2) + pil_images = [Image.fromarray(image) for image in predicted_images] + + # Log images + wandb_images = [wandb.Image(image, caption=f'mask ratio: {r:0.2f} \n caption: {texts[i]}') for i, (image, r) in + enumerate(zip(pil_images, mask_ratio))] + wandb.log({"Original images v.s. Reconstructed images v.s. Predicted images": wandb_images}, step=global_step) + + model.train() + + +@torch.no_grad() +def generate_images( + model, + vq_model, + uni_prompting, + accelerator, + config, + global_step, + mask_schedule, + force_no_cfg = False +): + logger.info("Generating images...") + model.eval() + + # read validation prompts from file + with open(config.dataset.params.validation_prompts_file, "r") as f: + validation_prompts = f.read().splitlines() + + mask_dtype = model.get_input_embeddings().weight.dtype + mask_token_id = accelerator.unwrap_model(model).config.mask_token_id + image_tokens = torch.ones((len(validation_prompts), config.model.omada.num_vq_tokens), dtype=torch.long, + device=accelerator.device) * mask_token_id + input_ids, attention_mask = uni_prompting((validation_prompts, image_tokens), 't2i_gen') + if not force_no_cfg and config.training.guidance_scale > 0: + uncond_input_ids, uncond_attention_mask = uni_prompting(([''] * len(validation_prompts), image_tokens), 't2i_gen') + cfg_scale = config.training.guidance_scale + else: + uncond_input_ids = None + uncond_attention_mask = None + cfg_scale = 0 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + else: + weight_dtype = torch.float32 + + with torch.autocast("cuda", dtype=weight_dtype, enabled=accelerator.mixed_precision != "no"): + # Generate images + gen_token_ids = accelerator.unwrap_model(model).t2i_generate( + input_ids=input_ids, + uncond_input_ids=uncond_input_ids, + attention_mask=attention_mask, + uncond_attention_mask=uncond_attention_mask, + guidance_scale=cfg_scale, + temperature=config.training.get("generation_temperature", 1.0), + timesteps=config.training.generation_timesteps, + noise_schedule=mask_schedule, + noise_type=config.training.get("noise_type", "mask"), + predict_all_tokens=config.training.get("predict_all_tokens", False), + seq_len=config.model.omada.num_vq_tokens, + uni_prompting=uni_prompting, + config=config, + ) + # In the beginning of training, the model is not fully trained and the generated token ids can be out of range + # so we clamp them to the correct range. + gen_token_ids = torch.clamp(gen_token_ids, max=accelerator.unwrap_model(model).config.codebook_size - 1, min=0) + images = vq_model.decode_code(gen_token_ids) + + model.train() + + if config.training.get("pre_encode", False): + del vq_model + + # Convert to PIL images + images = torch.clamp((images + 1.0) / 2.0, min=0.0, max=1.0) + images *= 255.0 + images = images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8) + pil_images = [Image.fromarray(image) for image in images] + + # Log images + wandb_images = [wandb.Image(image, caption=validation_prompts[i]) for i, image in enumerate(pil_images)] + wandb.log({f"Generated images with cfg {cfg_scale}": wandb_images}, step=global_step) + + + +@torch.no_grad() +def understanding_images( + model, + vq_model, + uni_prompting, + accelerator, + config, + global_step, +): + logger.info("Understanding images...") + model.eval() + + file_list = os.listdir(config.dataset.params.mmu_image_root) + file_list = [f for f in file_list if f.lower().endswith(('.jpg', '.png', '.jpeg'))] + responses = ['' for i in range(len(file_list))] + images = [] + + device = accelerator.device + + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + else: + weight_dtype = torch.float32 + + for i, file_name in enumerate(file_list): + image_path = os.path.join(config.dataset.params.mmu_image_root, file_name) + image_ori = Image.open(image_path).convert("RGB") + image = image_transform(image_ori, resolution=config.dataset.params.resolution).to(device) + image = image.unsqueeze(0) + images.append(image) + image_tokens = vq_model.get_code(image) + len(uni_prompting.text_tokenizer) + batch_size = 1 + + input_ids = uni_prompting.text_tokenizer(['<|start_header_id|>user<|end_header_id|>\n' + "Please describe this image in detail." +'<|start_header_id|>assistant<|end_header_id|>\n'])['input_ids'] + input_ids = torch.tensor(input_ids).to(device) + + input_ids = torch.cat([ + (torch.ones(input_ids.shape[0], 1) * uni_prompting.sptids_dict['<|mmu|>']).to(device), + (torch.ones(input_ids.shape[0], 1) * uni_prompting.sptids_dict['<|soi|>']).to(device), + image_tokens, + (torch.ones(input_ids.shape[0], 1) * uni_prompting.sptids_dict['<|eoi|>']).to(device), + (torch.ones(input_ids.shape[0], 1) * uni_prompting.sptids_dict['<|sot|>']).to(device), + input_ids + ], dim=1).long() + with torch.autocast("cuda", dtype=weight_dtype, enabled=accelerator.mixed_precision != "no"): + output_ids = accelerator.unwrap_model(model).mmu_generate(input_ids) + # output_ids = torch.stack(output_ids).squeeze()[None] + + text = uni_prompting.text_tokenizer.batch_decode(output_ids[:, input_ids.shape[1]:], skip_special_tokens=True) + responses[i] += text[0] + model.train() + images = torch.cat(images, dim=0) + images = torch.clamp((images + 1.0) / 2.0, min=0.0, max=1.0) + images *= 255.0 + images = images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8) + pil_images = [Image.fromarray(image) for image in images] + + # Log images + wandb_images = [wandb.Image(image, caption=responses[i]) for i, image in enumerate(pil_images)] + wandb.log({"Understanding images": wandb_images}, step=global_step) + + +def save_checkpoint(model, config, accelerator, global_step, uni_prompting): + output_dir = config.experiment.output_dir + checkpoints_total_limit = config.experiment.get("checkpoints_total_limit", None) + + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if accelerator.is_main_process and checkpoints_total_limit is not None: + checkpoints = os.listdir(output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= checkpoints_total_limit: + num_to_remove = len(checkpoints) - checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = Path(output_dir) / f"checkpoint-{global_step}" + + # retrieve the model on all processes for deepspeed stage 3 to work then save on one process (we are not using stage 3 yet) + # XXX: could also make this conditional on deepspeed + state_dict = accelerator.get_state_dict(model) + if accelerator.is_main_process: + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.save_pretrained( + save_path / "unwrapped_model", + save_function=accelerator.save, + state_dict=state_dict, + safe_serialization=True + ) + json.dump({"global_step": global_step}, (save_path / "metadata.json").open("w+")) + logger.info(f"Saved state to {save_path}") + + # save tokenizer + uni_prompting.text_tokenizer.save_pretrained(save_path/ "unwrapped_model") + + +def log_grad_norm(model, accelerator, global_step): + for name, param in model.named_parameters(): + if param.grad is not None: + grads = param.grad.detach().data + grad_norm = (grads.norm(p=2) / grads.numel()).item() + accelerator.log({"grad_norm/" + name: grad_norm}, step=global_step) + + +if __name__ == "__main__": + main() diff --git a/MMaDA/training/train_v2t_inst.py b/MMaDA/training/train_v2t_inst.py new file mode 100644 index 0000000000000000000000000000000000000000..44976c02f629a6efacf8f2b82b6cbbcd7ec587e8 --- /dev/null +++ b/MMaDA/training/train_v2t_inst.py @@ -0,0 +1,88 @@ +import torch +from torch.utils.data import DataLoader +from transformers import AutoTokenizer +from torchvision import transforms +from training.data import VideoCaptionDataset +from training.utils import image_transform + +# import your VideoCaptionDataset class here +# from your_module import VideoCaptionDataset + + +def debug_print_sample(sample, name="Sample"): + if sample is None: + print(f"āš ļø {name} is None — possible missing local files.\n") + return + + print(f"\n--- {name} ---") + print("Keys:", list(sample.keys())) + + for k, v in sample.items(): + if isinstance(v, torch.Tensor): + print(f"{k}: tensor, shape={tuple(v.shape)}") + + elif isinstance(v, list): + if len(v) == 0: + print(f"{k}: empty list") + elif isinstance(v[0], torch.Tensor): + print(f"{k}: list of tensors, len={len(v)}, shape[0]={tuple(v[0].shape)}") + else: + print(f"{k}: list[{len(v)}], first={v}") + + elif isinstance(v, dict): + print(f"{k}: dict with {len(v)} keys: {list(v.keys())[:5]}...") + + else: + print(f"{k}: {v}") + +def test_llavavid_dataset(): + # dummy transform (identity) + transform = transforms.Compose([ + transforms.ToTensor(), + ]) + + # use any tokenizer (doesn't matter for test) + tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") + + # instantiate dataset + llavavid_dataset = VideoCaptionDataset( + transform=image_transform, + tokenizer=tokenizer, + max_seq_length=128, + dataset_name="llavavid", # āœ… target dataset + ) + openvid_dataset = VideoCaptionDataset( + transform = image_transform, + tokenizer = tokenizer , + max_seq_length=128, + dataset_name = 'openvid1m' + ) + + print(f"āœ… šŸ“Š set loaded: {len(llavavid_dataset)} total entries (approx).") + + debug_print_sample(llavavid_dataset[9000], "LLaVA-VID Sample") + debug_print_sample(openvid_dataset[10000], "OpenVid Sample") + + # optionally wrap in DataLoader + llavavid_loader = DataLoader(llavavid_dataset, batch_size=1, shuffle=False) + batch = next(iter(llavavid_loader)) + print("\nāœ… LLavavid Dataloader working. Batch keys:", batch.keys()) + + openvid_loader = DataLoader(openvid_dataset, batch_size=1, shuffle=False) + batch = next(iter(openvid_loader)) + print("\nāœ… Openvid Dataloader working. Batch keys:", batch.keys()) + +import decord +decord.bridge.set_bridge('torch') + +def check(): + try: + vr = decord.VideoReader("/home/work/AIDAS/data/video/LLaVA-Video-178K/1_2_m_nextqa/NextQA/NExTVideo/1022/8252119570.mp4") + print(f"āœ… Frames: {len(vr)}") + except Exception as e: + print(f"āŒ Error loading video: {e}") + + +if __name__ == "__main__": + check() + #test_llavavid_dataset() diff --git a/MMaDA/training/utils.py b/MMaDA/training/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..006720db368d7923e23a5a103aca11accbc25dd4 --- /dev/null +++ b/MMaDA/training/utils.py @@ -0,0 +1,213 @@ +import math +import random +import torch +import torch.nn.functional as F +from omegaconf import DictConfig, ListConfig, OmegaConf +from typing import Any, List, Tuple, Union + + +################################################## +# config utils +################################################## +def get_config(): + cli_conf = OmegaConf.from_cli() + yaml_conf = OmegaConf.load(cli_conf.config) + conf = OmegaConf.merge(yaml_conf, cli_conf) + + return conf + + +def flatten_omega_conf(cfg: Any, resolve: bool = False) -> List[Tuple[str, Any]]: + ret = [] + + def handle_dict(key: Any, value: Any, resolve: bool) -> List[Tuple[str, Any]]: + return [(f"{key}.{k1}", v1) for k1, v1 in flatten_omega_conf(value, resolve=resolve)] + + def handle_list(key: Any, value: Any, resolve: bool) -> List[Tuple[str, Any]]: + return [(f"{key}.{idx}", v1) for idx, v1 in flatten_omega_conf(value, resolve=resolve)] + + if isinstance(cfg, DictConfig): + for k, v in cfg.items_ex(resolve=resolve): + if isinstance(v, DictConfig): + ret.extend(handle_dict(k, v, resolve=resolve)) + elif isinstance(v, ListConfig): + ret.extend(handle_list(k, v, resolve=resolve)) + else: + ret.append((str(k), v)) + elif isinstance(cfg, ListConfig): + for idx, v in enumerate(cfg._iter_ex(resolve=resolve)): + if isinstance(v, DictConfig): + ret.extend(handle_dict(idx, v, resolve=resolve)) + elif isinstance(v, ListConfig): + ret.extend(handle_list(idx, v, resolve=resolve)) + else: + ret.append((str(idx), v)) + else: + assert False + + return ret + + +################################################## +# training utils +################################################## +def soft_target_cross_entropy(logits, targets, soft_targets): + # ignore the first token from logits and targets (class id token) + logits = logits[:, 1:] + targets = targets[:, 1:] + + logits = logits[..., : soft_targets.shape[-1]] + + log_probs = F.log_softmax(logits, dim=-1) + padding_mask = targets.eq(-100) + + loss = torch.sum(-soft_targets * log_probs, dim=-1) + loss.masked_fill_(padding_mask, 0.0) + + # Take the mean over the label dimensions, then divide by the number of active elements (i.e. not-padded): + num_active_elements = padding_mask.numel() - padding_mask.long().sum() + loss = loss.sum() / num_active_elements + return loss + + +def get_loss_weight(t, mask, min_val=0.3): + return 1 - (1 - mask) * ((1 - t) * (1 - min_val))[:, None] + + +def mask_or_random_replace_tokens(image_tokens, mask_id, config, mask_schedule, is_train=True, seed=None): + batch_size, seq_len = image_tokens.shape + + if not is_train and seed is not None: + # äæå­˜å½“å‰éšęœŗēŠ¶ę€ + rng_state = torch.get_rng_state() + if torch.cuda.is_available(): + cuda_rng_state = torch.cuda.get_rng_state() + python_rng_state = random.getstate() + + # č®¾ē½®å›ŗå®šē§å­ + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + random.seed(seed) + # print(f"Set seed to {seed}") + + if not is_train and config.training.get("eval_mask_ratios", None): + mask_prob = random.choices(config.training.eval_mask_ratios, k=batch_size) + mask_prob = torch.tensor(mask_prob, device=image_tokens.device) + else: + # Sample a random timestep for each image + timesteps = torch.rand(batch_size, device=image_tokens.device) + # Sample a random mask probability for each image using timestep and cosine schedule + mask_prob = mask_schedule(timesteps) + mask_prob = mask_prob.clip(config.training.min_masking_rate) + + # creat a random mask for each image + num_token_masked = (seq_len * mask_prob).round().clamp(min=1) + + mask_contiguous_region_prob = config.training.get("mask_contiguous_region_prob", None) + + if mask_contiguous_region_prob is None: + mask_contiguous_region = False + else: + mask_contiguous_region = random.random() < mask_contiguous_region_prob + + if not mask_contiguous_region: + batch_randperm = torch.rand(batch_size, seq_len, device=image_tokens.device).argsort(dim=-1) + mask = batch_randperm < num_token_masked.unsqueeze(-1) + else: + resolution = int(seq_len ** 0.5) + mask = torch.zeros((batch_size, resolution, resolution), device=image_tokens.device) + + # TODO - would be nice to vectorize + for batch_idx, num_token_masked_ in enumerate(num_token_masked): + num_token_masked_ = int(num_token_masked_.item()) + + # NOTE: a bit handwavy with the bounds but gets a rectangle of ~num_token_masked_ + num_token_masked_height = random.randint( + math.ceil(num_token_masked_ / resolution), min(resolution, num_token_masked_) + ) + num_token_masked_height = min(num_token_masked_height, resolution) + + num_token_masked_width = math.ceil(num_token_masked_ / num_token_masked_height) + num_token_masked_width = min(num_token_masked_width, resolution) + + start_idx_height = random.randint(0, resolution - num_token_masked_height) + start_idx_width = random.randint(0, resolution - num_token_masked_width) + + mask[ + batch_idx, + start_idx_height: start_idx_height + num_token_masked_height, + start_idx_width: start_idx_width + num_token_masked_width, + ] = 1 + + mask = mask.reshape(batch_size, seq_len) + mask = mask.to(torch.bool) + + # mask images and create input and labels + if config.training.get("noise_type", "mask"): + input_ids = torch.where(mask, mask_id, image_tokens) + elif config.training.get("noise_type", "random_replace"): + # sample random tokens from the vocabulary + random_tokens = torch.randint_like( + image_tokens, low=0, high=config.model.codebook_size, device=image_tokens.device + ) + input_ids = torch.where(mask, random_tokens, image_tokens) + else: + raise ValueError(f"noise_type {config.training.noise_type} not supported") + + if ( + config.training.get("predict_all_tokens", False) + or config.training.get("noise_type", "mask") == "random_replace" + ): + labels = image_tokens + loss_weight = get_loss_weight(mask_prob, mask.long()) + else: + labels = torch.where(mask, image_tokens, -100) + loss_weight = None + + if not is_train and seed is not None: + # ę¢å¤éšęœŗēŠ¶ę€ + torch.set_rng_state(rng_state) + if torch.cuda.is_available(): + torch.cuda.set_rng_state(cuda_rng_state) + random.setstate(python_rng_state) + + return input_ids, labels, loss_weight, mask_prob + + +################################################## +# misc +################################################## +class AverageMeter(object): + """Computes and stores the average and current value""" + + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + +from torchvision import transforms +def image_transform(image, resolution=256, normalize=True): + image = transforms.Resize(resolution, interpolation=transforms.InterpolationMode.BICUBIC)(image) + image = transforms.CenterCrop((resolution, resolution))(image) + image = transforms.ToTensor()(image) + if normalize: + image = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)(image) + return image + +def image_transform_squash(image, resolution=256, normalize=True): + image = transforms.Resize((resolution,resolution), interpolation=transforms.InterpolationMode.BICUBIC)(image) + image = transforms.ToTensor()(image) + if normalize: + image = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)(image) + return image diff --git a/MMaDA/validation_prompts/imagenet_prompts.txt b/MMaDA/validation_prompts/imagenet_prompts.txt new file mode 100644 index 0000000000000000000000000000000000000000..3c09a358083bdab51e338768a86b3ec2a1e081c7 --- /dev/null +++ b/MMaDA/validation_prompts/imagenet_prompts.txt @@ -0,0 +1,24 @@ +golden retriever +tiger +wall clock +bicycle-built-for-two +coffee mug +laptop +banana +broccoli +pizza +garbage truck +red fox +horse cart +lion +acoustic guitar +egyptian cat +hummingbird +jeep +parachute +traffic light +ice cream +tree frog +mountain tent +speedboat +reflex camera \ No newline at end of file diff --git a/MMaDA/validation_prompts/math.txt b/MMaDA/validation_prompts/math.txt new file mode 100644 index 0000000000000000000000000000000000000000..bda05bed52721eabcbf3316d809b57ee178a49f0 --- /dev/null +++ b/MMaDA/validation_prompts/math.txt @@ -0,0 +1,4 @@ +A robe takes 2 bolts of blue fiber and half that much white fiber. How many bolts in total does it take? +James decides to run 3 sprints 3 times a week. He runs 60 meters each sprint. How many total meters does he run a week? +Kylar went to the store to buy glasses for his new apartment. One glass costs $5, but every second glass costs only 60% of the price. Kylar wants to buy 16 glasses. How much does he need to pay for them? +Toulouse has twice as many sheep as Charleston. Charleston has 4 times as many sheep as Seattle. How many sheep do Toulouse, Charleston, and Seattle have together if Seattle has 20 sheep? diff --git a/MMaDA/validation_prompts/quantative.txt b/MMaDA/validation_prompts/quantative.txt new file mode 100644 index 0000000000000000000000000000000000000000..ce882fc1ef2515be30cbff2c2d56d0ffdb17083e --- /dev/null +++ b/MMaDA/validation_prompts/quantative.txt @@ -0,0 +1,32 @@ +A 3D render of a futuristic car made of glass, driving through a city of mirrors. +A photo-realistic image of a garden with pink and blue flowers. There are pink poppies in the foreground, with their petals gently curved. The background features purple cosmos flowers. The flowers have water droplets on their petals, which glisten in the natural light. The green leaves are lush and healthy. The background is blurred, with a few trees and buildings visible. The overall image has a high resolution and is hyper-realistic, as if taken by a skilled photographer. +an egg and a bird made of wheat bread. +An armchair in the shape of an avocado +The image features a stylized stained glass illustration of a hummingbird with vibrant colors, set against a backdrop of swirling patterns and a large sun. The composition includes floral elements and intricate details, creating a vivid and dynamic scene that emphasizes the beauty of the bird. The colors range from greens to reds, enhancing the lively and artistic aesthetic of the piece. +A 3D render of a surreal explosion scene on the shore of a beautiful white sand beach with crystal clear water. The explosion has a spatter of oil paint with pastel colors and a thick consistency. The explosion is in a quiet and serene environment. A beautiful Japanese woman with a dress compacted to the sea is seen. There are butterfly petals and flowers with an ethereal glow and bioluminescence. There are pink and blue roses, and the overall image has a surreal and dreamlike quality. +A 3D render of a cute, round rice ball character with big, sparkling eyes that convey curiosity and joy. Its body is a soft, fluffy white with a slight sheen, resembling freshly cooked rice. Mochi has small, rosy cheeks that give it a warm, friendly expression. A tiny smile brightens its face, and it often sports a colorful ribbon tied around its "waist," adding a playful touch. Mochi's arms and feet are cartoonishly short, allowing it to bounce adorably around its surroundings. +A hyper-realistic close-up photograph of a woman's face, focusing on the left side. The image is highly detailed and realistic, showing voluminous glossy lips slightly parted, a well-defined nose, and open eyes with long eyelashes that cast shadows on the skin. The eye color is crystal clear almond green. The skin texture is crisp, with incredible detail of natural, lush skin and pores and freckles, with subtle highlights and shadows that give a realistic, close-up appearance. +A colorful cartoon of a tiger camouflaged in an abstract art painting, its stripes merging with the wild brushstrokes. +A 3D render of a cute, round rice ball character named Mochi, with big, sparkling eyes that convey curiosity and joy. Its body is a soft, fluffy white with a slight sheen, resembling freshly cooked rice. Mochi has small, rosy cheeks that give it a warm, friendly expression. A tiny smile brightens its face, and it often sports a colorful ribbon tied around its "waist," adding a playful touch. Mochi's arms and feet are cartoonishly short, allowing it to bounce adorably around its surroundings. This time, Mochi is placed against a background that is a vibrant explosion of colors, with bright hues of fuchsia, turquoise, lemon yellow, and emerald green creating a canvas of vibrant contrasts and playful energy. The clashing colors make Mochi's soft white body and rosy cheeks stand out even more, inviting viewers into a world of cheerful exuberance and visual delight. +The word 'mardefly' on a coffee mug. +An ancient spiritual gnomes stone pathway rock garden.sculptured . Style of alex grey,giger.unreal engine.totem sculptures.swirling patterned stone pathway courtyard.wood.stone.driftwood.statues. artistic sculpture.lanterns.air bnb tiny house +A droplet from a small brook enters the ocean, creating rippling water. A few bamboo leaves, white and green, are seen with a soft focus effect, glinting under the sun's twinkling rays. +A new Human Mecha combined, set against a massive post-apocalyptic background, realisticlighting, ruined ruins, high level of rendering, virtual reality. +A racing car with a silver transparent texture, showcasing design sensibility against a white background, industry design. +A beautiful woman facing to the camera, smiling confidently, colorful long hair, diamond necklace, deep red lip, medium shot, highly detailed, realistic, masterpiece. +Realism, Unreal Engine, cinematic feel, exaggerated lighting, cyberpunk, future world, advanced technology, neon city at night, a car chase scene, a busy road with many cars coming and going, a police car is chasing a yellow taxi, with a strong sense of speed, exaggerated lens effects, and movie screenshots. +Japanese anime, celluloid style, animation screenshots, cyberpunk, future world, technologically advanced, Close-up, A capable woman with dark blue short hair, wearing high-tech outfit, is driving, with a nervous expression. +a beautiful painting of the jungle in the morning with lots of smoke, fantasy art, matte painting +anthropomorphic crow werecreature, photograph captured in a forest +a beautiful portrait of a beautiful woman in the jungle surrounded by pink flowers, shamanism, matte painting, fantasy art +a cute cat +a dramatic volcanic eruption set against a dark blue night sky +a serene mountainous landscape featuring a calm, reflective pond in the foreground. The pond's surface mirrors the surrounding trees and rocks, creating a beautiful reflection that adds depth to the scene. On the left side, large boulders and rocky outcroppings are visible, partially submerged in the water, adding texture and interest to the composition. In the background, majestic mountains rise into the sky, their peaks adorned with patches of snow, indicating a high-altitude location. The trees, primarily coniferous, are lush and dense, their green foliage contrasting with the rocky terrain. Above, the sky is a mix of blue and white, with scattered clouds, suggesting a clear and sunny day. The overall scene is peaceful and picturesque, evoking a sense of tranquility and natural beauty. +A sea turtle swimming near a coral reef in the ocean, with a clear blue sky and water in the background. +a majestic male lion walking through a grassy savanna +a breathtaking landscape scene featuring a vast, barren salt flat under a dramatic sky +a serene indoor setting. Dominating the foreground is a black diffuser from the brand "Aromatherapy Associates", as indicated by the white text on its silver lid. +a scene from a laboratory setting. +a t-shirt of an avocado and a llama +a man is standing on a stage, holding a microphone in his hand. +a captivating scene of two fishing boats docked at a rocky shore \ No newline at end of file diff --git a/MMaDA/validation_prompts/test.txt b/MMaDA/validation_prompts/test.txt new file mode 100644 index 0000000000000000000000000000000000000000..24ed90f483975acac0f93f5ffff11f58be00df47 --- /dev/null +++ b/MMaDA/validation_prompts/test.txt @@ -0,0 +1 @@ +A dog and a cat sleeping \ No newline at end of file diff --git a/MMaDA/validation_prompts/text2image_prompts.txt b/MMaDA/validation_prompts/text2image_prompts.txt new file mode 100644 index 0000000000000000000000000000000000000000..7ab329e7bdac4501411713b373556263eb03f11c --- /dev/null +++ b/MMaDA/validation_prompts/text2image_prompts.txt @@ -0,0 +1,13 @@ +a dramatic volcanic eruption set against a dark blue night sky +a serene mountainous landscape featuring a calm, reflective pond in the foreground. The pond's surface mirrors the surrounding trees and rocks, creating a beautiful reflection that adds depth to the scene. On the left side, large boulders and rocky outcroppings are visible, partially submerged in the water, adding texture and interest to the composition. In the background, majestic mountains rise into the sky, their peaks adorned with patches of snow, indicating a high-altitude location. The trees, primarily coniferous, are lush and dense, their green foliage contrasting with the rocky terrain. Above, the sky is a mix of blue and white, with scattered clouds, suggesting a clear and sunny day. The overall scene is peaceful and picturesque, evoking a sense of tranquility and natural beauty. +A sea turtle swimming near a coral reef in the ocean, with a clear blue sky and water in the background. +a majestic male lion walking through a grassy savanna +a breathtaking landscape scene featuring a vast, barren salt flat under a dramatic sky +a family of four is captured in a moment of joy +a serene indoor setting. Dominating the foreground is a black diffuser from the brand "Aromatherapy Associates", as indicated by the white text on its silver lid. +a scene from a laboratory setting. +a t-shirt of an avocado and a llama +there is a woman who is the main subject +a man is standing on a stage, holding a microphone in his hand. +a captivating scene of two fishing boats docked at a rocky shore +a close-up of a woman's face, captured in what appears to be a mugshot setting \ No newline at end of file diff --git a/MMaDA/validation_prompts/text2speech_prompts.txt b/MMaDA/validation_prompts/text2speech_prompts.txt new file mode 100644 index 0000000000000000000000000000000000000000..6a052a6936bf9fbad9de8c76f22426a97f4050e5 --- /dev/null +++ b/MMaDA/validation_prompts/text2speech_prompts.txt @@ -0,0 +1 @@ +WHAT A SUUNY DAY \ No newline at end of file diff --git a/MMaDA/validation_prompts/text2speech_prompts_tmp.txt b/MMaDA/validation_prompts/text2speech_prompts_tmp.txt new file mode 100644 index 0000000000000000000000000000000000000000..e795100d78c63f6cd2707c39188400d102de513e --- /dev/null +++ b/MMaDA/validation_prompts/text2speech_prompts_tmp.txt @@ -0,0 +1 @@ +FUCK \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..460ffe859911ea39df69835e8f25a7b51f9192d2 --- /dev/null +++ b/README.md @@ -0,0 +1,13 @@ +--- +title: AIDAS Omni Modal Diffusion +emoji: šŸ‘ +colorFrom: blue +colorTo: purple +sdk: gradio +sdk_version: 5.49.1 +app_file: app.py +pinned: false +license: cc-by-4.0 +--- + +Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..6384e06a5e34d1e04883722caf1636fcdfba0135 --- /dev/null +++ b/app.py @@ -0,0 +1,109 @@ +""" +Gradio Space entrypoint mirroring `MMaDA/inference/gradio_multimodal_demo_inst.py`. +It downloads the published checkpoint once via huggingface_hub, wires it into +OmadaDemo, and launches the existing Blocks UI. + +Environment overrides: + MODEL_REPO_ID (default: jaeikkim/AIDAS-Omni-Modal-Diffusion) + MODEL_REVISION (default: main) + ASSET_REPO_ID (default: jaeikkim/AIDAS-Omni-Modal-Diffusion-assets) + ASSET_REVISION (default: main) + HF_TOKEN (optional, for private model/dataset) + TRAIN_CONFIG_PATH (default: MMaDA/inference/demo/demo.yaml) + DEVICE (default: auto cuda/cpu) + PORT (default: 7860; Space sets this) +""" + +import os +import sys +from pathlib import Path + +from huggingface_hub import snapshot_download + +# Ensure local project is importable +PROJECT_ROOT = Path(__file__).resolve().parent +MMADA_ROOT = PROJECT_ROOT / "MMaDA" +if str(MMADA_ROOT) not in sys.path: + sys.path.insert(0, str(MMADA_ROOT)) + +from inference.gradio_multimodal_demo_inst import OmadaDemo, build_demo # noqa: E402 + + +def download_assets() -> Path: + """Download demo assets (logo + sample prompts/media) and return the root path.""" + repo_id = os.getenv("ASSET_REPO_ID", "jaeikkim/AIDAS-Omni-Modal-Diffusion-assets") + revision = os.getenv("ASSET_REVISION", "main") + token = os.getenv("HF_TOKEN") + cache_dir = PROJECT_ROOT / "_asset_cache" + cache_dir.mkdir(parents=True, exist_ok=True) + + return Path( + snapshot_download( + repo_id=repo_id, + revision=revision, + repo_type="dataset", + local_dir=cache_dir, + local_dir_use_symlinks=False, + token=token, + ) + ) + + +def download_checkpoint() -> Path: + """Download checkpoint snapshot and return an `unwrapped_model` directory.""" + repo_id = os.getenv("MODEL_REPO_ID", "jaeikkim/AIDAS-Omni-Modal-Diffusion") + revision = os.getenv("MODEL_REVISION", "main") + token = os.getenv("HF_TOKEN") + cache_dir = PROJECT_ROOT / "_ckpt_cache" + cache_dir.mkdir(parents=True, exist_ok=True) + + snapshot_path = Path( + snapshot_download( + repo_id=repo_id, + revision=revision, + repo_type="model", + local_dir=cache_dir, + local_dir_use_symlinks=False, + token=token, + ) + ) + + # If snapshot itself is unwrapped_model, return it; otherwise point a symlink to it. + if snapshot_path.name == "unwrapped_model": + return snapshot_path + nested = snapshot_path / "unwrapped_model" + if nested.is_dir(): + return nested + aliased = snapshot_path.parent / "unwrapped_model" + if not aliased.exists(): + aliased.symlink_to(snapshot_path, target_is_directory=True) + return aliased + + +def main(): + checkpoint_dir = download_checkpoint() + asset_root = download_assets() + + # Point demo assets (logo, sample prompts/media) to the downloaded dataset + from inference import gradio_multimodal_demo_inst as demo_mod # noqa: WPS433 + + demo_root = asset_root / "demo" + demo_mod.DEMO_ROOT = demo_root + demo_mod.LOGO_PATH = demo_root / "logo.png" + demo_mod.T2S_TEXT_PATH = demo_root / "t2s" / "text.txt" + demo_mod.CHAT_TEXT_PATH = demo_root / "chat" / "text.txt" + demo_mod.T2I_TEXT_PATH = demo_root / "t2i" / "text.txt" + + train_config = os.getenv( + "TRAIN_CONFIG_PATH", + str(PROJECT_ROOT / "MMaDA" / "inference" / "demo" / "demo.yaml"), + ) + device = os.getenv("DEVICE") + port = int(os.getenv("PORT", "7860")) + + app = OmadaDemo(train_config=train_config, checkpoint=str(checkpoint_dir), device=device) + build_demo(app, share=False, server_name="0.0.0.0", server_port=port) + + +if __name__ == "__main__": + main() diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..067d6801debc5e110f8faa375776d1137d5d0f10 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,17 @@ +gradio==5.49.1 +huggingface-hub +transformers==4.46.0 +diffusers==0.32.2 +omegaconf +wandb +numpy +Pillow +opencv-python-headless +soundfile +torch +torchaudio +torchvision +tqdm +scipy +einops +sentencepiece