diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md new file mode 100644 index 0000000000000000000000000000000000000000..3232ed665566ec047ce55a929db1581dbda266a1 --- /dev/null +++ b/CODE_OF_CONDUCT.md @@ -0,0 +1,80 @@ +# Code of Conduct + +## Our Pledge + +In the interest of fostering an open and welcoming environment, we as +contributors and maintainers pledge to make participation in our project and +our community a harassment-free experience for everyone, regardless of age, body +size, disability, ethnicity, sex characteristics, gender identity and expression, +level of experience, education, socio-economic status, nationality, personal +appearance, race, religion, or sexual identity and orientation. + +## Our Standards + +Examples of behavior that contributes to creating a positive environment +include: + +* Using welcoming and inclusive language +* Being respectful of differing viewpoints and experiences +* Gracefully accepting constructive criticism +* Focusing on what is best for the community +* Showing empathy towards other community members + +Examples of unacceptable behavior by participants include: + +* The use of sexualized language or imagery and unwelcome sexual attention or +advances +* Trolling, insulting/derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or electronic +address, without explicit permission +* Other conduct which could reasonably be considered inappropriate in a +professional setting + +## Our Responsibilities + +Project maintainers are responsible for clarifying the standards of acceptable +behavior and are expected to take appropriate and fair corrective action in +response to any instances of unacceptable behavior. + +Project maintainers have the right and responsibility to remove, edit, or +reject comments, commits, code, wiki edits, issues, and other contributions +that are not aligned to this Code of Conduct, or to ban temporarily or +permanently any contributor for other behaviors that they deem inappropriate, +threatening, offensive, or harmful. + +## Scope + +This Code of Conduct applies within all project spaces, and it also applies when +an individual is representing the project or its community in public spaces. +Examples of representing a project or community include using an official +project e-mail address, posting via an official social media account, or acting +as an appointed representative at an online or offline event. Representation of +a project may be further defined and clarified by project maintainers. + +This Code of Conduct also applies outside the project spaces when there is a +reasonable belief that an individual's behavior may have a negative impact on +the project or its community. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be +reported by contacting the project team at . All +complaints will be reviewed and investigated and will result in a response that +is deemed necessary and appropriate to the circumstances. The project team is +obligated to maintain confidentiality with regard to the reporter of an incident. +Further details of specific enforcement policies may be posted separately. + +Project maintainers who do not follow or enforce the Code of Conduct in good +faith may face temporary or permanent repercussions as determined by other +members of the project's leadership. 
+ +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, +available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html + +[homepage]: https://www.contributor-covenant.org + +For answers to common questions about this code of conduct, see +https://www.contributor-covenant.org/faq diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000000000000000000000000000000000000..c0dcb6dfc97f0a41ef2eb08d20292a2c1f33d3db --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,31 @@ +# Contributing to spiritlm +We want to make contributing to this project as easy and transparent as +possible. + +## Pull Requests +We actively welcome your pull requests. + +1. Fork the repo and create your branch from `main`. +2. If you've added code that should be tested, add tests. +3. If you've changed APIs, update the documentation. +4. Ensure the test suite passes. +5. Make sure your code lints. +6. If you haven't already, complete the Contributor License Agreement ("CLA"). + +## Contributor License Agreement ("CLA") +In order to accept your pull request, we need you to submit a CLA. You only need +to do this once to work on any of Facebook's open source projects. + +Complete your CLA here: + +## Issues +We use GitHub issues to track public bugs. Please ensure your description is +clear and has sufficient instructions to be able to reproduce the issue. + +Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe +disclosure of security bugs. In those cases, please go through the process +outlined on that page and do not file a public issue. + +## License +By contributing to spiritlm, you agree that your contributions will be licensed +under the LICENSE file in the root directory of this source tree. \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..7e8b157116594e14309feac23853c2df20572fe8 --- /dev/null +++ b/LICENSE @@ -0,0 +1,116 @@ +FAIR Noncommercial Research License +Last Updated: October 18, 2024 + +“Acceptable Use Policy” means the FAIR Acceptable Use Policy, applicable to Research Materials, that is incorporated into this Agreement. + +“Agreement” means the terms and conditions for use, reproduction, distribution and modification of the Research Materials set forth herein. + +“Documentation” means the specifications, manuals and documentation accompanying +Research Materials distributed by Meta. + +“Licensee” or “you” means you, or your employer or any other person or entity (if you are entering into this Agreement on such person or entity’s behalf), of the age required under applicable laws, rules or regulations to provide legal consent and that has legal authority to bind your employer or such other person or entity if you are entering in this Agreement on their behalf. + +“Meta” or “we” means Meta Platforms Ireland Limited (if you are located in or, if you are an entity, your principal place of business is in the EEA or Switzerland) and Meta Platforms, Inc. (if you are located outside of the EEA or Switzerland). + +“Noncommercial Research Uses” means noncommercial research use cases related to research, development, education, processing, or analysis and in each case, is not primarily intended for commercial advantage or monetary compensation to you or others. 
+ +“Research Materials” means, collectively, Documentation and the models, software and algorithms, including machine-learning model code, trained model weights, inference-enabling code, training-enabling code, fine-tuning enabling code, demonstration materials and other elements of the foregoing distributed by Meta and made available under this Agreement. + +By clicking “I Accept” below or by using or distributing any portion or element of the Research Materials, you agree to be bound by this Agreement. + + +1. License Rights and Redistribution. + + a. Grant of Rights. You are granted a non-exclusive, worldwide, non-transferable and royalty-free limited license under Meta’s intellectual property or other rights owned by Meta embodied in the Research Materials to use, reproduce, distribute, copy, create derivative works of, and make modifications to the Research Materials. + + b. Redistribution and Use. + i. You will not use the Research Materials or any outputs or results of the Research Materials in connection with any commercial uses or for any uses other than Noncommercial Research Uses; + + ii. Distribution of Research Materials, and any derivative works thereof, are subject to the terms of this Agreement. If you distribute or make the Research Materials, or any derivative works thereof, available to a third party, you may only do so under the terms of this Agreement. You shall also provide a copy of this Agreement to such third party. + + iii. If you submit for publication the results of research you perform on, using, or otherwise in connection with Research Materials, you must acknowledge the use of Research Materials in your publication. + + iv. Your use of the Research Materials must comply with applicable laws and regulations (including Trade Control Laws) and adhere to the FAIR Acceptable Use Policy, which is hereby incorporated by reference into this Agreement. + +2. User Support. Your Noncommercial Research Use of the Research Materials is done at your own discretion; Meta does not process any information nor provide any service in relation to such use. Meta is under no obligation to provide any support services for the Research Materials. Any support provided is “as is”, “with all faults”, and without warranty of any kind. + +3. Disclaimer of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE RESEARCH MATERIALS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN “AS IS” BASIS, WITHOUT WARRANTIES OF ANY KIND, AND META DISCLAIMS ALL WARRANTIES OF ANY KIND, BOTH EXPRESS AND IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING THE RESEARCH MATERIALS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE RESEARCH MATERIALS AND ANY OUTPUT AND RESULTS. + +4. Limitation of Liability. IN NO EVENT WILL META OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY DIRECT OR INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF META OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING. + +5. Intellectual Property. + + a. 
Subject to Meta’s ownership of Research Materials and derivatives made by or for Meta, with respect to any derivative works and modifications of the Research Materials that are made by you, as between you and Meta, you are and will be the owner of such derivative works and modifications. + + b. If you institute litigation or other proceedings against Meta or any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Research Materials, outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by you, then any licenses granted to you under this Agreement shall terminate as of the date such litigation or claim is filed or instituted. You will indemnify and hold harmless Meta from and against any claim by any third party arising out of or related to your use or distribution of the Research Materials. + +6. Term and Termination. The term of this Agreement will commence upon your acceptance of this Agreement or access to the Research Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein. Meta may terminate this Agreement if you are in breach of any term or condition of this Agreement. Upon termination of this Agreement, you shall delete and cease use of the Research Materials. Sections 5, 6 and 9 shall survive the termination of this Agreement. + +7. Governing Law and Jurisdiction. This Agreement will be governed and construed under the laws of the State of California without regard to choice of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this Agreement. The courts of California shall have exclusive jurisdiction of any dispute arising out of this Agreement. + +8. Modifications and Amendments. Meta may modify this Agreement from time to time by posting a revised version at [https://github.com/facebookresearch/spiritlm/blob/main/LICENSE]; provided that they are similar in spirit to the current version of the Agreement, but may differ in detail to address new problems or concerns. All such changes will be effective immediately. Your continued use of the Research Materials after any modification to this Agreement constitutes your agreement to such modification. Except as provided in this Agreement, no modification or addition to any provision of this Agreement will be binding unless it is in writing and signed by an authorized representative of both you and Meta. + + +FAIR Acceptable Use Policy + +The Fundamental AI Research (FAIR) team at Meta seeks to further understanding of new and existing research domains with the mission of advancing the state-of-the-art in artificial intelligence through open research for the benefit of all. + +As part of this mission, Meta makes certain research materials available for noncommercial research use. Meta is committed to promoting the safe and responsible use of such research materials. + + +Prohibited Uses + +You agree you will not use, or allow others to use, Research Materials to: + + 1. Violate the law or others’ rights, including to: + a. Engage in, promote, generate, contribute to, encourage, plan, incite, or further illegal or unlawful activity or content, such as: + i. Violence or terrorism + ii. Exploitation or harm to children, including the solicitation, creation, acquisition, or dissemination of child exploitative content or failure to report Child Sexual Abuse Material + iii. 
Human trafficking, exploitation, and sexual violence + iv. The illegal distribution of information or materials to minors, including obscene materials, or failure to employ legally required age-gating in connection with such information or materials. + v. Sexual solicitation + iv. Any other criminal activity + + b. Engage in, promote, incite, or facilitate the harassment, abuse, threatening, or bullying of individuals or groups of individuals + + c. Engage in, promote, incite, or facilitate discrimination or other unlawful or harmful conduct in the provision of employment, employment benefits, credit, housing, other economic benefits, or other essential goods and services + + d. Engage in the unauthorized or unlicensed practice of any profession including, but not limited to, financial, legal, medical/health, or related professional practices + + e. Collect, process, disclose, generate, or infer health, demographic, or other sensitive personal or private information about individuals without rights and consents required by applicable laws + + f. Engage in or facilitate any action or generate any content that infringes, misappropriates, or otherwise violates any third-party rights, including the outputs or results of any technology using FAIR research materials + + g. Create, generate, or facilitate the creation of malicious code, malware, computer viruses or do anything else that could disable, overburden, interfere with or impair the proper working, integrity, operation or appearance of a website or computer system + + 2. Engage in, promote, incite, facilitate, or assist in the planning or development of activities that present a risk of death or bodily harm to individuals, including use of research artifacts related to the following: + + a. Military, warfare, nuclear industries or applications, espionage, use for materials or activities that are subject to the International Traffic Arms Regulations (ITAR) maintained by the United States Department of State + + b. Guns and illegal weapons (including weapon development) + + c. Illegal drugs and regulated/controlled substances + + d. Operation of critical infrastructure, transportation technologies, or heavy machinery + + e. Self-harm or harm to others, including suicide, cutting, and eating disorders + + f. Any content intended to incite or promote violence, abuse, or any infliction of bodily harm to an individual + + 3. Intentionally deceive or mislead others, including use of FAIR Research Materials related to the following: + + a. Generating, promoting, or furthering fraud or the creation or promotion of disinformation + + b. Generating, promoting, or furthering defamatory content, including the creation of defamatory statements, images, or other content + + c. Generating, promoting, or further distributing spam + + d. Impersonating another individual without consent, authorization, or legal right + + e. Representing that outputs of FAIR research materials or outputs from technology using FAIR research materials o are human-generated + + f. Generating or facilitating false online engagement, including fake reviews and other means of fake online engagement + + 4. Fail to appropriately disclose to end users any known dangers of your Research Materials. + +Please report any violation of this Policy or other problems that could lead to a violation of this Policy by submitting a report here [https://docs.google.com/forms/d/e/1FAIpQLSeb11cryAopJ7LNrC4nxEUXrHY26hfkXQMf_uH-oFgA3WlYZQ/viewform]. 
diff --git a/MODEL_CARD.md b/MODEL_CARD.md new file mode 100644 index 0000000000000000000000000000000000000000..d57bf311b849645756cef6f4b19e27f6b687e074 --- /dev/null +++ b/MODEL_CARD.md @@ -0,0 +1,81 @@ +# Meta Spirit LM Model Card + +## Model Details + +*Note: Use of this model is governed by the FAIR Noncommercial Research License.* + +Spirit LM is a multimodal language model that freely mixes text and speech. The model can be prompted with either text or speech and is capable of generating outputs in either modality, while preserving the expressivity of the input prompt. The model is also able to learn new tasks across modalities such as automatic speech recognition, text-to-speech, and speech classification in a few-shot manner. + +## Model Developers +Meta + +## Variations +Spirit LM comes in two versions: Spirit LM Base that uses speech phonetic tokens and Spirit LM Expressive that models expressivity using pitch and style tokens in addition to the phonetic tokens. + +## Input +Models input text or speech or a mixed sequence of the two. + +## Output +Models generate text or speech or a mixed sequence of the two. + +## Model Architecture +### Speech Tokenizer +Spirit LM uses 3 types of speech tokenizers: Phonetic Tokenizer (HuBERT), Pitch Tokenizer (VQ-VAE) and Style Tokenizer (Speechprop or Wav2vec2). We use Hifi-GAN to convert the speech tokens back to audio. + +It is worth noting that in the associated paper, for Spirit LM Expressive, we used Speechprop to extract style tokens, while we use a Wav2vec2 model to extract style tokens in this release. + +| | Model | Parameters | Input | Output | +|------------------------|--------------------------|------------|---------------------|--------------------| +| Phonetic Tokenizer | HuBERT+LinearQuantizer | 96M | Waveform | Phonetic Tokens | +| Pitch Tokenizer | VQ-VAE | 0.2M | Extracted F0 | Pitch Tokens | +| Style Tokenizer | Wav2vec2+LinearProjection| 95M | Waveform | Style Tokens | +| Base Speech Decoder | Hifi-GAN | 14M | Phonetic Tokens | Waveform | +| Expressive Speech Decoder | Hifi-GAN | 15M | Phonetic, Pitch, Style Tokens | Waveform + +### Language Model +Spirit LM is initialized from the Llama-2 7B model. + +| | Architecture | Parameters | Input/Output Tokens | Vocab Size | +|----------------------|----------------|------------|----------------------------------------------------------|------------| +| Spirit LM Base | Llama-2 7B | 7B | Text Tokens, Phonetic Tokens | 32512 | +| Spirit LM Expressive | Llama-2 7B | 7B | Text Tokens, Phonetic Tokens, Pitch Tokens, Style Tokens | 32768 | + +### Release Date +The models were trained between October and December 2023. The research paper was released on February 8th 2024. We released the model on October 18th 2024. + +### Status +This is a static model trained on an offline dataset. + +### License +We release the model under the FAIR Noncommercial Research License found in the [LICENSE](LICENSE) file in the root directory of this repo. + +### Research Paper +More information can be found in the paper ["SpiRit-LM: Interleaved Spoken and Written Language Model"](https://arxiv.org/pdf/2402.05755.pdf). + +## Hardware and Software +### Training Factors +We used custom training libraries. The training of the released models has been performed on Meta’s Research Clusters. + +The training of each model (Spirit LM Base and Spirit LM Expressive) takes 21K GPU hours of computation on hardware of type A100-80GB (TDP of 350-400W), not including the training of Llama-2. 
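+
+## Example Usage
+For reference, the snippet below is a minimal generation sketch that mirrors the example notebook shipped in this repository (`examples/speech_generation/spirit_model.ipynb`); it assumes the checkpoints have been downloaded and placed under `checkpoints/` as described in [checkpoints/README.md](checkpoints/README.md).
+
+```python
+from transformers import GenerationConfig
+
+from spiritlm.model.spiritlm_model import (
+    ContentType,
+    GenerationInput,
+    OutputModality,
+    Spiritlm,
+)
+
+# Load the base variant ("spirit-lm-expressive-7b" selects the expressive one).
+spirit_lm = Spiritlm("spirit-lm-base-7b")
+
+# Prompt with text and constrain the generation to text output.
+outputs = spirit_lm.generate(
+    interleaved_inputs=[
+        GenerationInput(
+            content="The largest country in the world is",
+            content_type=ContentType.TEXT,
+        )
+    ],
+    output_modality=OutputModality.TEXT,
+    generation_config=GenerationConfig(max_new_tokens=20, do_sample=False),
+)
+print(outputs[0].content)  # the generated text continuation
+```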
+
+## Training Data
+We trained the models on a combination of text-only datasets, speech-only datasets and aligned speech-text datasets. All the speech datasets are publicly available. Here are the statistics of the datasets we used:
+
+|              | Hours | Speech Tokens | Text Tokens |
+|--------------|-------|---------------|-------------|
+| Speech-only  | 458K  | 28.2B         | -           |
+| Speech+Text  | 111K  | 7.0B          | 1.4B        |
+| Text-only    | -     | -             | 307B        |
+
+## Evaluation Results
+See evaluations for our models and detailed ablations in Sections 4 and 5, and safety evaluations in Section 6 of the [research paper](https://arxiv.org/pdf/2402.05755.pdf).
+
+## Intended Use
+### Intended Use Cases
+Spirit LM is intended for noncommercial research use in English.
+
+### Out-of-Scope Uses
+Use in any manner that violates applicable laws or regulations (including trade compliance laws). Use in languages other than English. Use in any other way that is prohibited by the FAIR Noncommercial Research License and Acceptable Use Policy.
+
+## Ethical Considerations and Limitations
+This model is built on Llama 2, which carries risks with use. Testing conducted to date has been in English and has not covered, nor could it cover, all scenarios. For these reasons, as with all LLMs, Llama 2’s potential outputs cannot be predicted in advance, and the model may in some instances produce inaccurate, biased or other objectionable responses to user prompts. The model’s speech capabilities are designed to analyze speaker-agnostic qualities of any input speech and output speech in one of four pre-set voices. The model is meant for noncommercial research purposes only and should not be deployed in any consumer-facing applications.
\ No newline at end of file
diff --git a/README.md b/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..2770b3ceec8468fcaba38dde38e4f2cf944c358f
--- /dev/null
+++ b/README.md
@@ -0,0 +1,55 @@
+# Meta Spirit LM: Interleaved Spoken and Written Language Model
+
+This repository contains the model weights, inference code and evaluation scripts for the Spirit LM [paper](https://arxiv.org/pdf/2402.05755.pdf). You can find more generation samples on our [demo page](https://speechbot.github.io/spiritlm/).
+
+## Spirit LM Model Overview
+
+![Spirit LM overview](assets/spiritlm_overview.png)
+
+## Installation Setup
+### Conda
+```
+conda env create -f env.yml
+pip install -e '.[eval]'
+```
+### Pip
+```
+pip install -e '.[eval]'
+```
+
+### Dev
+(Optional; only needed if you want to run the tests.)
+```
+pip install -e '.[dev]'
+```
+
+## Checkpoints Setup
+See [checkpoints/README.md](checkpoints/README.md)
+
+## Quick Start
+### Speech Tokenization
+See [spiritlm/speech_tokenizer/README.md](spiritlm/speech_tokenizer/README.md)
+### Spirit LM Generation
+See [spiritlm/model/README.md](spiritlm/model/README.md)
+### Speech-Text Sentiment Preservation benchmark (STSP)
+See [spiritlm/eval/README.md](spiritlm/eval/README.md)
+
+## Model Card
+More details of the model can be found in [MODEL_CARD.md](MODEL_CARD.md).
+
+## License
+The present code is provided under the **FAIR Noncommercial Research License** found in [LICENSE](LICENSE).
+
+## Citation
+```
+@misc{nguyen2024spiritlminterleavedspokenwritten,
+      title={SpiRit-LM: Interleaved Spoken and Written Language Model},
+      author={Tu Anh Nguyen and Benjamin Muller and Bokai Yu and Marta R.
Costa-jussa and Maha Elbayad and Sravya Popuri and Paul-Ambroise Duquenne and Robin Algayres and Ruslan Mavlyutov and Itai Gat and Gabriel Synnaeve and Juan Pino and Benoit Sagot and Emmanuel Dupoux}, + year={2024}, + eprint={2402.05755}, + archivePrefix={arXiv}, + primaryClass={cs.CL}, + url={https://arxiv.org/abs/2402.05755}, +} +``` + diff --git a/assets/spiritlm_overview.png b/assets/spiritlm_overview.png new file mode 100644 index 0000000000000000000000000000000000000000..6c900c98877c6bbf5efb526f7ba6964d022cfd8d Binary files /dev/null and b/assets/spiritlm_overview.png differ diff --git a/checkpoints/README.md b/checkpoints/README.md new file mode 100644 index 0000000000000000000000000000000000000000..3308a3fcd044f59cf9e79e20caa835c2719d1460 --- /dev/null +++ b/checkpoints/README.md @@ -0,0 +1,52 @@ +# Spirit LM Checkpoints + +## Download Checkpoints +To access and download Spirit LM Checkpoints, please request the model artifacts in this link: + +[https://ai.meta.com/resources/models-and-libraries/spirit-lm-downloads/](https://ai.meta.com/resources/models-and-libraries/spirit-lm-downloads/) + +Upon approval, you will then receive an email with download links to each model artifact. + +Please note that Spirit LM is made available under the **FAIR Noncommercial Research License** +found in the [LICENSE](../LICENSE) file in the root directory of this source tree and Acceptable Use Policy. + +## Structure +The checkpoints directory should look like this: +``` +checkpoints/ +├── README.md +├── speech_tokenizer +│   ├── hifigan_spiritlm_base +│   │   ├── config.json +│   │   ├── generator.pt +│   │   ├── speakers.txt +│   │   └── styles.txt +│   ├── hifigan_spiritlm_expressive_w2v2 +│   │   ├── config.json +│   │   ├── generator.pt +│   │   └── speakers.txt +│   ├── hubert_25hz +│   │   ├── L11_quantizer_500.pt +│   │   └── mhubert_base_25hz.pt +│   ├── style_encoder_w2v2 +│   │   ├── config.json +│   │   └── pytorch_model.bin +│   └── vqvae_f0_quantizer +│   ├── config.yaml +│   └── model.pt +└── spiritlm_model + ├── spirit-lm-base-7b + │   ├── config.json + │   ├── generation_config.json + │   ├── pytorch_model.bin + │   ├── special_tokens_map.json + │   ├── tokenizer_config.json + │   └── tokenizer.model + └── spirit-lm-expressive-7b + ├── config.json + ├── generation_config.json + ├── pytorch_model.bin + ├── special_tokens_map.json + ├── tokenizer_config.json + └── tokenizer.model +``` diff --git a/checkpoints/speech_tokenizer/hifigan_spiritlm_base/config.json b/checkpoints/speech_tokenizer/hifigan_spiritlm_base/config.json new file mode 100644 index 0000000000000000000000000000000000000000..c075413f2298990e4cdccb2d86fc60a52e5b0483 --- /dev/null +++ b/checkpoints/speech_tokenizer/hifigan_spiritlm_base/config.json @@ -0,0 +1,55 @@ +{ + + "resblock": "1", + "num_gpus": 0, + "batch_size": 16, + "learning_rate": 0.0002, + "adam_b1": 0.8, + "adam_b2": 0.99, + "lr_decay": 0.999, + "seed": 1234, + + "upsample_rates": [5,4,4,4,2], + "upsample_kernel_sizes": [11,8,8,8,4], + "upsample_initial_channel": 512, + "resblock_kernel_sizes": [3,7,11], + "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]], + "num_embeddings": 501, + "embedding_dim": 128, + "model_in_dim": 384, + + "segment_size": 8960, + "code_hop_size": 640, + "f0": false, + "num_mels": 80, + "num_freq": 1025, + "n_fft": 1024, + "hop_size": 256, + "win_size": 1024, + + "multispkr": "from_input_file", + "num_speakers": 4, + "multistyle": "from_input_file", + "num_styles": 34, + + "dur_prediction_weight": 1.0, + "dur_predictor_params": 
{ + "encoder_embed_dim": 128, + "var_pred_hidden_dim": 128, + "var_pred_kernel_size": 3, + "var_pred_dropout": 0.5 + }, + + "sampling_rate": 16000, + + "fmin": 0, + "fmax": 8000, + "fmax_for_loss": null, + + "num_workers": 4, + + "dist_config": { + "dist_backend": "nccl", + "dist_url": "env://" + } +} diff --git a/checkpoints/speech_tokenizer/hifigan_spiritlm_base/generator.pt b/checkpoints/speech_tokenizer/hifigan_spiritlm_base/generator.pt new file mode 100644 index 0000000000000000000000000000000000000000..ef1a253590823e40e96e54bdfac7b839de8ff084 --- /dev/null +++ b/checkpoints/speech_tokenizer/hifigan_spiritlm_base/generator.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d66c49067aeff93b14b038f143c2dc9ed981671956512d8e702897416f13c459 +size 57512631 diff --git a/checkpoints/speech_tokenizer/hifigan_spiritlm_base/speakers.txt b/checkpoints/speech_tokenizer/hifigan_spiritlm_base/speakers.txt new file mode 100644 index 0000000000000000000000000000000000000000..69b18f438a99d66f94d9ac0de545566253113da8 --- /dev/null +++ b/checkpoints/speech_tokenizer/hifigan_spiritlm_base/speakers.txt @@ -0,0 +1,4 @@ +ex04 +ex02 +ex03 +ex01 \ No newline at end of file diff --git a/checkpoints/speech_tokenizer/hifigan_spiritlm_base/styles.txt b/checkpoints/speech_tokenizer/hifigan_spiritlm_base/styles.txt new file mode 100644 index 0000000000000000000000000000000000000000..8351b70de90f7affa598d54b97a5d9fbde50b6c0 --- /dev/null +++ b/checkpoints/speech_tokenizer/hifigan_spiritlm_base/styles.txt @@ -0,0 +1,34 @@ +read-default +read-enunciated +read-confused +read-laughing +read-whisper +read-sad +read-happy +conv-projected +conv-default +conv-sympathetic +conv-fast +conv-disgusted +conv-laughing +conv-calm +conv-sarcastic +conv-whisper +conv-angry +conv-sad +conv-happy +conv-enunciated +conv-awe +read-singing +conv-confused +conv-fearful +conv-narration +conv-sleepy +conv-child +conv-animal +conv-childdir +conv-animaldir +conv-bored +conv-desire +conv-nonverbal +read-narration \ No newline at end of file diff --git a/checkpoints/speech_tokenizer/hifigan_spiritlm_expressive_w2v2/config.json b/checkpoints/speech_tokenizer/hifigan_spiritlm_expressive_w2v2/config.json new file mode 100644 index 0000000000000000000000000000000000000000..50a2f5d63b92670388a743aef2734d182ef2553e --- /dev/null +++ b/checkpoints/speech_tokenizer/hifigan_spiritlm_expressive_w2v2/config.json @@ -0,0 +1,60 @@ +{ + + "resblock": "1", + "num_gpus": 0, + "batch_size": 128, + "learning_rate": 0.0002, + "adam_b1": 0.8, + "adam_b2": 0.99, + "lr_decay": 0.999, + "seed": 1234, + + "upsample_rates": [5,4,4,4,2], + "upsample_kernel_sizes": [11,8,8,8,4], + "upsample_initial_channel": 512, + "resblock_kernel_sizes": [3,7,11], + "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]], + + "multispkr": "from_input_file", + "multistyle": null, + + "dur_prediction_weight": 1.0, + "dur_predictor_params": { + "encoder_embed_dim": 128, + "var_pred_hidden_dim": 128, + "var_pred_kernel_size": 3, + "var_pred_dropout": 0.5 + }, + + "segment_size": 17920, + "code_hop_size": 640, + "f0_hop_size": 1280, + "style_hop_size": 16000, + + "num_embeddings": 501, + "num_f0_tokens": 64, + "num_style_tokens": 100, + "num_speakers": 4, + + "embedding_dim": 128, + "model_in_dim": 512, + + "num_mels": 80, + "num_freq": 1025, + "n_fft": 1024, + "hop_size": 256, + "win_size": 1024, + + "sampling_rate": 16000, + + "fmin": 0, + "fmax": 8000, + "fmax_for_loss": null, + + "num_workers": 4, + + "dist_config": { + "dist_backend": "nccl", + "dist_url": 
"env://" + } +} \ No newline at end of file diff --git a/checkpoints/speech_tokenizer/hifigan_spiritlm_expressive_w2v2/generator.pt b/checkpoints/speech_tokenizer/hifigan_spiritlm_expressive_w2v2/generator.pt new file mode 100644 index 0000000000000000000000000000000000000000..b6f05b1185bcf90c85d55d8aab31433ca823196f --- /dev/null +++ b/checkpoints/speech_tokenizer/hifigan_spiritlm_expressive_w2v2/generator.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9de97cbc336e6113c27f17560988a22e76df25af6a8e944ad01676aabec31326 +size 59414584 diff --git a/checkpoints/speech_tokenizer/hifigan_spiritlm_expressive_w2v2/speakers.txt b/checkpoints/speech_tokenizer/hifigan_spiritlm_expressive_w2v2/speakers.txt new file mode 100644 index 0000000000000000000000000000000000000000..69b18f438a99d66f94d9ac0de545566253113da8 --- /dev/null +++ b/checkpoints/speech_tokenizer/hifigan_spiritlm_expressive_w2v2/speakers.txt @@ -0,0 +1,4 @@ +ex04 +ex02 +ex03 +ex01 \ No newline at end of file diff --git a/checkpoints/speech_tokenizer/hubert_25hz/L11_quantizer_500.pt b/checkpoints/speech_tokenizer/hubert_25hz/L11_quantizer_500.pt new file mode 100644 index 0000000000000000000000000000000000000000..15289b872e57aaab429b17c1c20a798c80567bdc --- /dev/null +++ b/checkpoints/speech_tokenizer/hubert_25hz/L11_quantizer_500.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:06b408c0a487f0218e8aba52ca30d9de54e0c36af8bebfd151265647d221080b +size 5222060 diff --git a/checkpoints/speech_tokenizer/hubert_25hz/mhubert_base_25hz.pt b/checkpoints/speech_tokenizer/hubert_25hz/mhubert_base_25hz.pt new file mode 100644 index 0000000000000000000000000000000000000000..3184371ae8c4c210709cebda2915e4bd5a055a35 --- /dev/null +++ b/checkpoints/speech_tokenizer/hubert_25hz/mhubert_base_25hz.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1421f060cf92b9d2ea72dcecd7f30ce549e1112e24b62b43f2ec3026301051bb +size 383333938 diff --git a/checkpoints/speech_tokenizer/style_encoder_w2v2/config.json b/checkpoints/speech_tokenizer/style_encoder_w2v2/config.json new file mode 100644 index 0000000000000000000000000000000000000000..621b9c7ecdd1bc21a1515efcd30b5a077f792fc7 --- /dev/null +++ b/checkpoints/speech_tokenizer/style_encoder_w2v2/config.json @@ -0,0 +1,321 @@ +{ + "_name_or_path": "facebook/wav2vec2-base", + "activation_dropout": 0.0, + "adapter_kernel_size": 3, + "adapter_stride": 2, + "add_adapter": false, + "apply_spec_augment": true, + "architectures": [ + "Wav2Vec2ForPooledSequenceClassification" + ], + "attention_dropout": 0.1, + "bos_token_id": 1, + "classifier_proj_size": 256, + "codevector_dim": 256, + "contrastive_logits_temperature": 0.1, + "conv_bias": false, + "conv_dim": [ + 512, + 512, + 512, + 512, + 512, + 512, + 512 + ], + "conv_kernel": [ + 10, + 3, + 3, + 3, + 3, + 2, + 2 + ], + "conv_stride": [ + 5, + 2, + 2, + 2, + 2, + 2, + 2 + ], + "ctc_loss_reduction": "sum", + "ctc_zero_infinity": false, + "diversity_loss_weight": 0.1, + "do_stable_layer_norm": false, + "eos_token_id": 2, + "feat_extract_activation": "gelu", + "feat_extract_norm": "group", + "feat_proj_dropout": 0.1, + "feat_quantizer_dropout": 0.0, + "final_dropout": 0.0, + "freeze_feat_extract_train": true, + "hidden_act": "gelu", + "hidden_dropout": 0.1, + "hidden_size": 768, + "id2label": { + "0": "0", + "1": "1", + "10": "10", + "11": "11", + "12": "12", + "13": "13", + "14": "14", + "15": "15", + "16": "16", + "17": "17", + "18": "18", + "19": "19", + "2": "2", + "20": "20", + "21": "21", + 
"22": "22", + "23": "23", + "24": "24", + "25": "25", + "26": "26", + "27": "27", + "28": "28", + "29": "29", + "3": "3", + "30": "30", + "31": "31", + "32": "32", + "33": "33", + "34": "34", + "35": "35", + "36": "36", + "37": "37", + "38": "38", + "39": "39", + "4": "4", + "40": "40", + "41": "41", + "42": "42", + "43": "43", + "44": "44", + "45": "45", + "46": "46", + "47": "47", + "48": "48", + "49": "49", + "5": "5", + "50": "50", + "51": "51", + "52": "52", + "53": "53", + "54": "54", + "55": "55", + "56": "56", + "57": "57", + "58": "58", + "59": "59", + "6": "6", + "60": "60", + "61": "61", + "62": "62", + "63": "63", + "64": "64", + "65": "65", + "66": "66", + "67": "67", + "68": "68", + "69": "69", + "7": "7", + "70": "70", + "71": "71", + "72": "72", + "73": "73", + "74": "74", + "75": "75", + "76": "76", + "77": "77", + "78": "78", + "79": "79", + "8": "8", + "80": "80", + "81": "81", + "82": "82", + "83": "83", + "84": "84", + "85": "85", + "86": "86", + "87": "87", + "88": "88", + "89": "89", + "9": "9", + "90": "90", + "91": "91", + "92": "92", + "93": "93", + "94": "94", + "95": "95", + "96": "96", + "97": "97", + "98": "98", + "99": "99" + }, + "initializer_range": 0.02, + "intermediate_size": 3072, + "label2id": { + "0": "0", + "1": "1", + "10": "10", + "11": "11", + "12": "12", + "13": "13", + "14": "14", + "15": "15", + "16": "16", + "17": "17", + "18": "18", + "19": "19", + "2": "2", + "20": "20", + "21": "21", + "22": "22", + "23": "23", + "24": "24", + "25": "25", + "26": "26", + "27": "27", + "28": "28", + "29": "29", + "3": "3", + "30": "30", + "31": "31", + "32": "32", + "33": "33", + "34": "34", + "35": "35", + "36": "36", + "37": "37", + "38": "38", + "39": "39", + "4": "4", + "40": "40", + "41": "41", + "42": "42", + "43": "43", + "44": "44", + "45": "45", + "46": "46", + "47": "47", + "48": "48", + "49": "49", + "5": "5", + "50": "50", + "51": "51", + "52": "52", + "53": "53", + "54": "54", + "55": "55", + "56": "56", + "57": "57", + "58": "58", + "59": "59", + "6": "6", + "60": "60", + "61": "61", + "62": "62", + "63": "63", + "64": "64", + "65": "65", + "66": "66", + "67": "67", + "68": "68", + "69": "69", + "7": "7", + "70": "70", + "71": "71", + "72": "72", + "73": "73", + "74": "74", + "75": "75", + "76": "76", + "77": "77", + "78": "78", + "79": "79", + "8": "8", + "80": "80", + "81": "81", + "82": "82", + "83": "83", + "84": "84", + "85": "85", + "86": "86", + "87": "87", + "88": "88", + "89": "89", + "9": "9", + "90": "90", + "91": "91", + "92": "92", + "93": "93", + "94": "94", + "95": "95", + "96": "96", + "97": "97", + "98": "98", + "99": "99" + }, + "layer_norm_eps": 1e-05, + "layerdrop": 0.0, + "mask_channel_length": 10, + "mask_channel_min_space": 1, + "mask_channel_other": 0.0, + "mask_channel_prob": 0.0, + "mask_channel_selection": "static", + "mask_feature_length": 10, + "mask_feature_min_masks": 0, + "mask_feature_prob": 0.0, + "mask_time_length": 10, + "mask_time_min_masks": 2, + "mask_time_min_space": 1, + "mask_time_other": 0.0, + "mask_time_prob": 0.05, + "mask_time_selection": "static", + "model_type": "wav2vec2", + "no_mask_channel_overlap": false, + "no_mask_time_overlap": false, + "num_adapter_layers": 3, + "num_attention_heads": 12, + "num_codevector_groups": 2, + "num_codevectors_per_group": 320, + "num_conv_pos_embedding_groups": 16, + "num_conv_pos_embeddings": 128, + "num_feat_extract_layers": 7, + "num_hidden_layers": 12, + "num_negatives": 100, + "output_hidden_size": 768, + "pad_token_id": 0, + "proj_codevector_dim": 256, + 
"tdnn_dilation": [ + 1, + 2, + 3, + 1, + 1 + ], + "tdnn_dim": [ + 512, + 512, + 512, + 512, + 1500 + ], + "tdnn_kernel": [ + 5, + 3, + 3, + 1, + 1 + ], + "torch_dtype": "float32", + "transformers_version": "4.25.1", + "use_weighted_layer_sum": false, + "vocab_size": 32, + "xvector_output_dim": 512 +} diff --git a/checkpoints/speech_tokenizer/style_encoder_w2v2/pytorch_model.bin b/checkpoints/speech_tokenizer/style_encoder_w2v2/pytorch_model.bin new file mode 100644 index 0000000000000000000000000000000000000000..c1bd0142c7bd52104ec62bb51471ebd0b8a0bdb0 --- /dev/null +++ b/checkpoints/speech_tokenizer/style_encoder_w2v2/pytorch_model.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:819b7a87bfcf4041252c3f1825915924e10d5a06a2a66da546beaca98c7ab8bc +size 378451177 diff --git a/checkpoints/speech_tokenizer/vqvae_f0_quantizer/config.yaml b/checkpoints/speech_tokenizer/vqvae_f0_quantizer/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9808d0fe43241a02c743070c5c4e5324802932cf --- /dev/null +++ b/checkpoints/speech_tokenizer/vqvae_f0_quantizer/config.yaml @@ -0,0 +1,59 @@ + +seed: 1234 + +# Data +f0_path: '' +p_train: 0.95 +min_frames: null +batch_size: 128 +features: f0_interp,vuv +out_features: norm_f0_interp,vuv +segment_size: null +segment_multi: 16 +num_workers: 4 +vuv_scale: 2 +speaker_stats: '' +recon_loss_fn: l1_loss + + +# Optimization +learning_rate: 0.0002 +adam_b1: 0.8 +adam_b2: 0.99 +lr_decay: 0.999 +lambda_commit: 0.02 + +# VQ params +vq_params: + l_bins: 64 + emb_width: 128 + mu: 0.99 + levels: 1 + +# Encoder params +encoder_params: + input_emb_width: 2 + output_emb_width: 128 + levels: 1 + downs_t: + - 4 + strides_t: + - 2 + width: 32 + depth: 4 + m_conv: 1.0 + dilation_growth_rate: 3 + +# Decoder params +decoder_params: + input_emb_width: 2 + output_emb_width: 128 + levels: 1 + downs_t: + - 4 + strides_t: + - 2 + width: 32 + depth: 4 + m_conv: 1.0 + dilation_growth_rate: 3 diff --git a/checkpoints/speech_tokenizer/vqvae_f0_quantizer/model.pt b/checkpoints/speech_tokenizer/vqvae_f0_quantizer/model.pt new file mode 100644 index 0000000000000000000000000000000000000000..3cd828843fbde3a4b394f1418013c44e186d8a3e --- /dev/null +++ b/checkpoints/speech_tokenizer/vqvae_f0_quantizer/model.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f4321b0c19b47279ab3d76c4b6e85bbc439156a4dd6478919856accaf4180382 +size 2600601 diff --git a/data/examples/pred.jsonl b/data/examples/pred.jsonl new file mode 100644 index 0000000000000000000000000000000000000000..19b6dc89a1ce717418e768fee30c1d755bdf5920 --- /dev/null +++ b/data/examples/pred.jsonl @@ -0,0 +1,5 @@ +{"pred": "angry", "id": 4792320029370491913} +{"pred": "neutral", "id": -5682350483296949563} +{"pred": "amused", "id": -8754508989367964614} +{"pred": "angry", "id": -9018665079841831624} +{"pred": "neutral", "id": 1159246029716120600} diff --git a/data/examples/ref.jsonl b/data/examples/ref.jsonl new file mode 100644 index 0000000000000000000000000000000000000000..d4cdda45b176b3c6332b23aaba29e0137aa22323 --- /dev/null +++ b/data/examples/ref.jsonl @@ -0,0 +1,5 @@ +{"emotion": "angry", "sentiment": "negative", "wav_path": "emov/sam/Angry/anger_281-308_0286.wav", "split": "test", "speaker": "sam", "id": 4792320029370491913} +{"emotion": "neutral", "sentiment": "neutral", "wav_path": "emov/sam/Neutral/neutral_281-308_0286.wav", "split": "test", "speaker": "sam", "id": -5682350483296949563} +{"emotion": "amused", "sentiment": "positive", "wav_path": 
"emov/sam/Amused/amused_281-308_0286.wav", "split": "test", "speaker": "sam", "id": -8754508989367964614} +{"emotion": "angry", "sentiment": "negative", "wav_path": "emov/jenie/Angry/anger_57-84_0084.wav", "split": "test", "speaker": "jenie", "id": -9018665079841831624} +{"emotion": "neutral", "sentiment": "neutral", "wav_path": "emov/jenie/Neutral/neutral_57-84_0084.wav", "split": "test", "speaker": "jenie", "id": 1159246029716120600} diff --git a/env.yml b/env.yml new file mode 100644 index 0000000000000000000000000000000000000000..3c1952993a7d624cf3de10a284304e4890c8f5df --- /dev/null +++ b/env.yml @@ -0,0 +1,19 @@ +name: spiritlm_test +channels: + - pytorch + - nvidia +dependencies: + - python=3.9 + - pip + - pytorch-cuda=11.8 + - pytorch + - torchaudio + - pip: + - omegaconf==2.2.0 + - librosa~=0.10 + - local-attention~=1.9 + - encodec~=0.1 + - transformers + - fairscale~=0.4 + - sentencepiece + - torchfcpe~=0.0.4 \ No newline at end of file diff --git a/examples/audio/7143-88743-0029.flac b/examples/audio/7143-88743-0029.flac new file mode 100644 index 0000000000000000000000000000000000000000..ee101efc8ae986ae37ce645a431ed60b9c802d53 Binary files /dev/null and b/examples/audio/7143-88743-0029.flac differ diff --git a/examples/distributed_inference_recipe/multi_nodes.slurm b/examples/distributed_inference_recipe/multi_nodes.slurm new file mode 100644 index 0000000000000000000000000000000000000000..8511e6501bdd9edf4ed5440653a2bf9a0474ac0f --- /dev/null +++ b/examples/distributed_inference_recipe/multi_nodes.slurm @@ -0,0 +1,24 @@ +#!/bin/bash + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the FAIR Noncommercial Research License +# found in the LICENSE file in the root directory of this source tree. + +#SBATCH --job-name=spiritlm +#SBATCH --ntasks-per-node=1 +#SBATCH --gpus-per-node=8 +#SBATCH --nodes=2 +#SBATCH --cpus-per-task=12 +#SBATCH --output=./logs/%j.stdout +#SBATCH --error=./logs/%j.stderr +#SBATCH --time=01:00:00 + +set -e + +srun bash -c 'torchrun --nnodes $SLURM_JOB_NUM_NODES --nproc-per-node $SLURM_GPUS_ON_NODE \ +--node-rank $SLURM_PROCID \ +--master-addr $(scontrol show hostnames $SLURM_NODELIST | head -n1) \ +--master-port 12345 \ +examples/distributed_inference_recipe/run_dist.py' diff --git a/examples/distributed_inference_recipe/run_dist.py b/examples/distributed_inference_recipe/run_dist.py new file mode 100644 index 0000000000000000000000000000000000000000..99d6306665f29b00e8c285989302b84ee18ccb27 --- /dev/null +++ b/examples/distributed_inference_recipe/run_dist.py @@ -0,0 +1,89 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the FAIR Noncommercial Research License +# found in the LICENSE file in the root directory of this source tree. + +""" +Usage example: + +cd {SPIRITLM ROOT FOLDER} +export PYTHONPATH=. 
+
+Single node, multi-GPU:
+(Assume that your machine has 8 GPUs)
+    torchrun --nnodes 1 --nproc-per-node 8 examples/distributed_inference_recipe/run_dist.py
+
+Multi-node, multi-GPU:
+(2 nodes, 8 GPUs per node, via sbatch)
+    mkdir -p logs
+    sbatch examples/distributed_inference_recipe/multi_nodes.slurm
+"""
+
+import os
+
+import torch
+import torch.distributed as dist
+import torchaudio
+from spiritlm.model.spiritlm_model import (
+    ContentType,
+    GenerationInput,
+    OutputModality,
+    Spiritlm,
+)
+from torch.utils.data import TensorDataset
+from torch.utils.data.distributed import DistributedSampler
+from transformers import GenerationConfig, set_seed
+
+
+def run(seed: int = 0):
+    world_size = int(os.environ["WORLD_SIZE"])
+    world_rank = int(os.environ["RANK"])
+    print(
+        f"Running distributed inference with world_size: {world_size}, world_rank: {world_rank}"
+    )
+    dist.init_process_group("nccl", rank=world_rank, world_size=world_size)
+
+    set_seed(seed)
+
+    wav = torchaudio.load("examples/audio/7143-88743-0029.flac")[0].squeeze()
+
+    # fake repeated dataset
+    dataset = TensorDataset(wav.repeat(32, 1))
+
+    sampler = DistributedSampler(dataset=dataset)
+    loader = torch.utils.data.DataLoader(
+        dataset=dataset,
+        batch_size=1,  # don't change
+        sampler=sampler,
+        num_workers=4,
+    )
+
+    spirit_lm = Spiritlm("spirit-lm-expressive-7b")
+
+    for _, data in enumerate(loader):
+        outs = spirit_lm.generate(
+            output_modality=OutputModality.ARBITRARY,
+            interleaved_inputs=[
+                GenerationInput(
+                    content=data[0],  # 0 because of batch size 1
+                    content_type=ContentType.SPEECH,
+                )
+            ],
+            generation_config=GenerationConfig(
+                temperature=0.9,
+                top_p=0.95,
+                max_new_tokens=200,
+                do_sample=True,
+            ),
+        )
+        print(f"outs: {outs}")
+
+
+def setup_env():
+    os.environ["OMP_NUM_THREADS"] = "1"
+
+
+if __name__ == "__main__":
+    setup_env()
+    run()
diff --git a/examples/speech_generation/spirit_model.ipynb b/examples/speech_generation/spirit_model.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..1ea11e085d6847f14af15574816209f2bb7c3390
--- /dev/null
+++ b/examples/speech_generation/spirit_model.ipynb
@@ -0,0 +1,625 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from spiritlm.model.spiritlm_model import Spiritlm, OutputModality, GenerationInput, ContentType\n",
+    "\n",
+    "from transformers import GenerationConfig\n",
+    "import IPython.display as ipd\n",
+    "\n",
+    "def display_outputs(outputs):\n",
+    "    for output in outputs:\n",
+    "        if output.content_type == ContentType.TEXT:\n",
+    "            print(output.content)\n",
+    "        else:\n",
+    "            ipd.display(ipd.Audio(output.content, rate=16_000))"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "We support two variants of Spirit LM models, `Spirit LM Base` and `Spirit LM Expressive`. Both `Spirit LM Base` and `Spirit LM Expressive` are fine-tuned from the 7B Llama 2 model on text-only, speech-only and aligned speech+text datasets.\n",
+    "\n",
+    "Compared to `Spirit LM Base`, `Spirit LM Expressive` captures not only the semantics but also **expressivity** from the speech."
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## `Spirit LM Base`" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/.conda/envs/spiritlm/lib/python3.10/site-packages/torch/nn/utils/weight_norm.py:134: FutureWarning: `torch.nn.utils.weight_norm` is deprecated in favor of `torch.nn.utils.parametrizations.weight_norm`.\n", + " WeightNorm.apply(module, name, dim)\n" + ] + } + ], + "source": [ + "spirit_lm = Spiritlm(\"spirit-lm-base-7b\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Generation" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The input `interleaved_inputs` of `generate` function is a list of either\n", + "- `GenerationInput` composed of `content_type` and `content`, or\n", + "- tuple of (`'speech'`/`'text'`, `content`)\n", + "\n", + "the inputs are interleaved following the order of the list.\n", + "\n", + "`output_modality` controls the output modality.\n", + "- If you want to generate only the text, specify it to `OutputModality.TEXT` or `'text'`;\n", + "- If you want to generate only the speech, specify it to `OutputModality.SPEECH` or `'speech'`;\n", + "- If you don't have the constraint over the generation's modality, use `OutputModality.ARBITRARY` or `'arbitrary'`;\n", + "\n", + "The output of generation is also a list (of `GenerationOuput`), when `output_modality` is `OutputModality.TEXT` or `OutputModality.SPEECH`, the list should have only one element.\n", + "When `output_modality` is `OutputModality.ARBITRARY`, the list can have multiple elements from different types (`ContentType.TEXT` or `ContentType.SPEECH`).\n", + "\n", + "The generation arguments can either be passed through `generation_config=GenerationConfig(args)` or directly in `generate(args)`.\n", + "\n", + "For a full list of generation arguments, see:\n", + "https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig\n", + "\n", + "Note that the following two commands give the same outputs:" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/.conda/envs/spiritlm/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:579: UserWarning: `pad_token_id` should be positive but got -1. This will cause errors when batch generating, if there is padding. Please set `pad_token_id` explicitly as `model.generation_config.pad_token_id=PAD_TOKEN_ID` to avoid errors in generation\n", + " warnings.warn(\n", + "Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)\n" + ] + }, + { + "data": { + "text/plain": [ + "[GenerationOuput(content='Russia. 
Russia is a country that is located in the northern part of the Eurasian continent.', content_type=)]" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "spirit_lm.generate(\n", + " interleaved_inputs=[\n", + " GenerationInput(\n", + " content=\"The largest country in the world is\",\n", + " content_type=ContentType.TEXT,\n", + " )\n", + " ],\n", + " output_modality=OutputModality.TEXT,\n", + " generation_config=GenerationConfig(\n", + " max_new_tokens=20,\n", + " do_sample=False,\n", + " ),\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[GenerationOuput(content='Russia. Russia is a country that is located in the northern part of the Eurasian continent.', content_type=)]" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "spirit_lm.generate(\n", + " interleaved_inputs=[('text', \"The largest country in the world is\")],\n", + " output_modality='text',\n", + " max_new_tokens=20,\n", + " do_sample=False,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### T -> T generation" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "A very cute cat, a black and white cat, named Meow was born to a family that was very good to her.\n", + "She\n" + ] + } + ], + "source": [ + "outputs = spirit_lm.generate(\n", + " interleaved_inputs=[('text', \"Here is a story about a cute cat named Meow:\")],\n", + " output_modality='text',\n", + " generation_config=GenerationConfig(\n", + " temperature=0.8,\n", + " top_p=0.95,\n", + " max_new_tokens=30,\n", + " do_sample=True,\n", + " ),\n", + ")\n", + "display_outputs(outputs)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### T -> S generation" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "outputs = spirit_lm.generate(\n", + " interleaved_inputs=[('text', \"One of the most beautiful cities in the world is\")],\n", + " output_modality='speech',\n", + " generation_config=GenerationConfig(\n", + " temperature=0.8,\n", + " top_p=0.95,\n", + " max_new_tokens=200,\n", + " do_sample=True,\n", + " ),\n", + ")\n", + "display_outputs(outputs)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### S -> T generation" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "When the `content` is speech, we accept several types:\n", + "1) The audio `Path`: e.g., `\"examples/audio/7143-88743-0029.flac\"` or `Path(\"examples/audio/7143-88743-0029.flac\")`\n", + "2) The audio `bytes`: e.g., `open(\"examples/audio/7143-88743-0029.flac\", \"rb\").read()`\n", + "3) The audio `Tensor`: e.g., `torchaudio.load(\"examples/audio/7143-88743-0029.flac\")[0].squeeze(0)`" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ipd.Audio(\"../audio/7143-88743-0029.flac\")" + ] + }, + { + "cell_type": "code", + "execution_count": 8, 
+ "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "the old man led the way to a corner of the cave where he kept his stock of skins and furs in a pile and there were\n" + ] + } + ], + "source": [ + "outputs = spirit_lm.generate(\n", + " interleaved_inputs=[('speech', \"../audio/7143-88743-0029.flac\")],\n", + " output_modality='text',\n", + " generation_config=GenerationConfig(\n", + " temperature=0.8,\n", + " top_p=0.95,\n", + " max_new_tokens=30,\n", + " do_sample=True,\n", + " ),\n", + ")\n", + "display_outputs(outputs)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### S -> S generation" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "outputs = spirit_lm.generate(\n", + " interleaved_inputs=[('speech', \"../audio/7143-88743-0029.flac\")],\n", + " output_modality='speech',\n", + " generation_config=GenerationConfig(\n", + " temperature=0.8,\n", + " top_p=0.95,\n", + " max_new_tokens=200,\n", + " do_sample=True,\n", + " ),\n", + ")\n", + "display_outputs(outputs)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Arbitrary generation" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " i want to see it he had a big knife in his hand and he cut off a strip of the skin of the ox hide and\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + " good loop to hold it up well i said you are a man he cried and so you think you are and so you are now it is i am glad of that well here i\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "interleaved_outputs = spirit_lm.generate(\n", + " interleaved_inputs=[('speech', \"../audio/7143-88743-0029.flac\")],\n", + " output_modality='arbitrary',\n", + " generation_config=GenerationConfig(\n", + " temperature=0.8,\n", + " top_p=0.95,\n", + " max_new_tokens=200,\n", + " do_sample=True,\n", + " ),\n", + ")\n", + "display_outputs(interleaved_outputs)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Specify the prompt by a string of tokens\n", + "\n", + "This could be useful when you construct the few-shots prompt.\n", + "\n", + "Note that when `prompt` is given, `generation_inputs` is not used." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "outputs = spirit_lm.generate(\n", + " prompt=\"[St71][Pi39][Hu99][Hu49][Pi57][Hu38][Hu149][Pi48][Hu71][Hu423][Hu427][Pi56][Hu492][Hu288][Pi44][Hu315][Hu153][Pi42][Hu389][Pi59][Hu497][Hu412][Pi51][Hu247][Hu354][Pi44][Hu7][Hu96][Pi43][Hu452][Pi0][Hu176][Hu266][Pi54][St71][Hu77][Pi13][Hu248][Hu336][Pi39][Hu211][Pi25][Hu166][Hu65][Pi58][Hu94][Hu224][Pi26][Hu148][Pi44][Hu492][Hu191][Pi26][Hu440][Pi13][Hu41][Pi20][Hu457][Hu79][Pi46][Hu382][Hu451][Pi26][Hu332][Hu216][Hu114][Hu340][St71][Pi40][Hu478][Hu74][Pi26][Hu79][Hu370][Pi56][Hu272][Hu370][Pi51][Hu53][Pi14][Hu477][Hu65][Pi46][Hu171][Hu60][Pi41][Hu258][Hu111][Pi40][Hu338][Hu23][Pi39][Hu338][Hu23][Hu338][St71][Pi57][Hu7][Hu338][Hu149][Pi59][Hu406][Hu7][Hu361][Hu99][Pi20][Hu209][Hu479][Pi35][Hu50][St71][Hu7][Hu149][Pi55][Hu35][Pi13][Hu130][Pi3][Hu169][Pi52][Hu72][Pi9][Hu434][Hu119][Hu272][Hu4][Pi20][Hu249][Hu245][Pi57][Hu433][Pi56][Hu159][Hu294][Hu139][Hu359][Hu343][Hu269][Hu302][St71][Hu226][Pi32][Hu370][Hu216][Pi39][Hu459][Hu424][Pi57][Hu226][Pi46][Hu382][Hu7][Pi27][Hu58][Hu138][Pi20][Hu428][Hu397][Pi44][Hu350][Pi32][Hu306][Pi59][Hu84][Hu11][Hu171][Pi42][Hu60][Pi48][Hu314][Hu227][St71][Hu355][Pi56][Hu9][Hu58][Pi44][Hu138][Hu226][Pi25][Hu370][Hu272][Pi56][Hu382][Hu334][Pi26][Hu330][Hu176][Pi56][Hu307][Pi46][Hu145][Hu248][Pi56][Hu493][Hu64][Pi40][Hu44][Hu388][Pi39][Hu7][Hu111][Pi59][St71][Hu23][Hu481][Pi13][Hu149][Pi15][Hu80][Hu70][Pi47][Hu431][Hu457][Pi13][Hu79][Pi27][Hu249][Pi55][Hu245][Pi54][Hu433][Pi36][Hu316][Pi53][Hu180][Pi3][Hu458][Pi26][Hu86][St71][Pi43][Hu225][Pi49][Hu103][Hu60][Pi3][Hu96][Hu119][Pi39][Hu129][Pi41][Hu356][Hu218][Pi14][Hu4][Hu259][Pi56][Hu392][Pi46][Hu490][Hu75][Pi14][Hu488][Hu166][Pi46][Hu65][Hu171][Pi40][Hu60][Hu7][Hu54][Pi39][Hu85][St83][Pi40][Hu361]\",\n", + " output_modality='speech',\n", + " generation_config=GenerationConfig(\n", + " temperature=0.8,\n", + " top_p=0.95,\n", + " max_new_tokens=200,\n", + " do_sample=True,\n", + " ),\n", + ")\n", + "display_outputs(outputs)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## `Spirit LM Expressive`" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " [INFO]: device is not None, use cuda:0\n", + " [INFO] > call by:torchfcpe.tools.spawn_infer_cf_naive_mel_pe_from_pt\n", + " [WARN] args.model.use_harmonic_emb is None; use default False\n", + " [WARN] > call by:torchfcpe.tools.spawn_cf_naive_mel_pe\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/.conda/envs/spiritlm/lib/python3.10/site-packages/torchfcpe/models_infer.py:191: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. 
Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n", + " ckpt = torch.load(pt_path, map_location=torch.device(device))\n", + "Some weights of Wav2Vec2StyleEncoder were not initialized from the model checkpoint at checkpoints/speech_tokenizer/style_encoder_w2v2 and are newly initialized: ['_float_tensor']\n", + "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n", + "/home/.conda/envs/spiritlm/lib/python3.10/site-packages/torch/nn/utils/weight_norm.py:134: FutureWarning: `torch.nn.utils.weight_norm` is deprecated in favor of `torch.nn.utils.parametrizations.weight_norm`.\n", + " WeightNorm.apply(module, name, dim)\n" + ] + } + ], + "source": [ + "spirit_lm = Spiritlm(\"spirit-lm-expressive-7b\")" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/.conda/envs/spiritlm/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:579: UserWarning: `pad_token_id` should be positive but got -1. This will cause errors when batch generating, if there is padding. Please set `pad_token_id` explicitly as `model.generation_config.pad_token_id=PAD_TOKEN_ID` to avoid errors in generation\n", + " warnings.warn(\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "outputs = spirit_lm.generate(\n", + " interleaved_inputs=[('text', \"I am so deeply saddened, it feels as if my heart is shattering into a million pieces and I can't hold back the tears that are streaming down my face.\")],\n", + " output_modality='speech',\n", + " generation_config=GenerationConfig(\n", + " temperature=0.8,\n", + " top_p=0.95,\n", + " max_new_tokens=200,\n", + " do_sample=True,\n", + " ),\n", + " speaker_id=1,\n", + ")\n", + "display_outputs(outputs)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "outputs = spirit_lm.generate(\n", + " interleaved_inputs=[('text', \"Wow!!! Congratulations!!! 
I'm so excited that\")],\n", + " output_modality='speech',\n", + " generation_config=GenerationConfig(\n", + " temperature=0.8,\n", + " top_p=0.95,\n", + " max_new_tokens=200,\n", + " do_sample=True,\n", + " ),\n", + " speaker_id=1,\n", + ")\n", + "display_outputs(outputs)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.4" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/examples/speech_tokenizer/spiritlm_speech_tokenizer.ipynb b/examples/speech_tokenizer/spiritlm_speech_tokenizer.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..4ca80ebfc3dca639b81e28779beb7506224f8ccc --- /dev/null +++ b/examples/speech_tokenizer/spiritlm_speech_tokenizer.ipynb @@ -0,0 +1,563 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "2798b8a9-1f91-4d0c-bc48-d6dd7b27613c", + "metadata": {}, + "source": [ + "## Load audio" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "41a0601b-102d-40dc-ac60-823a1cbb07c5", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Original audio:\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import IPython.display as ipd\n", + "audio = \"../audio/7143-88743-0029.flac\"\n", + "print('Original audio:')\n", + "ipd.display(ipd.Audio(audio))" + ] + }, + { + "cell_type": "markdown", + "id": "098d128f-a120-4c14-819b-8241546c98b9", + "metadata": {}, + "source": [ + "## SpiritLM tokenizers" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "e8128d42-4017-46d3-9095-732be0e3fc87", + "metadata": {}, + "outputs": [], + "source": [ + "from spiritlm.speech_tokenizer import spiritlm_base, spiritlm_expressive" + ] + }, + { + "cell_type": "markdown", + "id": "9ea2b4b6-cf4b-46b1-9e6b-a8a296c48400", + "metadata": {}, + "source": [ + "### SpiritLM-BASE tokenizer" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "9d1d8e02-58e7-4ec6-9238-ee3bcad01a5e", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/.conda/envs/spiritlm/lib/python3.10/site-packages/torch/nn/utils/weight_norm.py:134: FutureWarning: `torch.nn.utils.weight_norm` is deprecated in favor of `torch.nn.utils.parametrizations.weight_norm`.\n", + " WeightNorm.apply(module, name, dim)\n" + ] + } + ], + "source": [ + "## Load the tokenizer\n", + "tokenizer_base = spiritlm_base()" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "2d42054d-ad63-4c4e-aa6a-cda8b0321673", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "SpiritLM-BASE: Encode audio into units (not deduplicated) \n", + " --------------------\n", + "{'audio': 'examples/audio/7143-88743-0029.flac', 'hubert': '99 49 38 149 149 71 423 427 492 288 315 153 153 389 497 412 247 354 7 96 452 452 176 266 266 77 248 336 336 211 166 65 94 224 224 148 492 191 440 440 41 41 457 79 382 451 332 216 114 340 478 74 79 370 272 370 370 53 477 65 171 60 258 111 111 111 
111 338 338 23 23 338 23 338 338 338 7 338 338 149 406 7 361 361 361 99 99 99 99 99 99 99 209 209 209 209 209 479 50 50 7 149 149 35 35 130 130 169 169 72 434 119 272 4 249 245 245 433 159 294 139 359 343 269 302 226 370 216 459 424 424 226 382 7 58 138 428 397 350 350 306 306 306 84 11 171 171 60 314 227 227 355 9 58 138 226 370 272 382 334 330 176 176 307 145 248 493 64 44 388 7 111 111 111 111 23 23 481 149 149 80 70 431 457 79 79 249 249 245 245 245 433 433 316 316 180 458 458 458 86 86 225 103 60 96 119 119 129 356 218 4 259 259 392 490 75 488 166 65 171 60 7 54 54 85 85 361 361'}\n", + "\n", + "SpiritLM-BASE: Encode audio into string (deduplicated and sorted units) \n", + " --------------------\n", + "[Hu99][Hu49][Hu38][Hu149][Hu71][Hu423][Hu427][Hu492][Hu288][Hu315][Hu153][Hu389][Hu497][Hu412][Hu247][Hu354][Hu7][Hu96][Hu452][Hu176][Hu266][Hu77][Hu248][Hu336][Hu211][Hu166][Hu65][Hu94][Hu224][Hu148][Hu492][Hu191][Hu440][Hu41][Hu457][Hu79][Hu382][Hu451][Hu332][Hu216][Hu114][Hu340][Hu478][Hu74][Hu79][Hu370][Hu272][Hu370][Hu53][Hu477][Hu65][Hu171][Hu60][Hu258][Hu111][Hu338][Hu23][Hu338][Hu23][Hu338][Hu7][Hu338][Hu149][Hu406][Hu7][Hu361][Hu99][Hu209][Hu479][Hu50][Hu7][Hu149][Hu35][Hu130][Hu169][Hu72][Hu434][Hu119][Hu272][Hu4][Hu249][Hu245][Hu433][Hu159][Hu294][Hu139][Hu359][Hu343][Hu269][Hu302][Hu226][Hu370][Hu216][Hu459][Hu424][Hu226][Hu382][Hu7][Hu58][Hu138][Hu428][Hu397][Hu350][Hu306][Hu84][Hu11][Hu171][Hu60][Hu314][Hu227][Hu355][Hu9][Hu58][Hu138][Hu226][Hu370][Hu272][Hu382][Hu334][Hu330][Hu176][Hu307][Hu145][Hu248][Hu493][Hu64][Hu44][Hu388][Hu7][Hu111][Hu23][Hu481][Hu149][Hu80][Hu70][Hu431][Hu457][Hu79][Hu249][Hu245][Hu433][Hu316][Hu180][Hu458][Hu86][Hu225][Hu103][Hu60][Hu96][Hu119][Hu129][Hu356][Hu218][Hu4][Hu259][Hu392][Hu490][Hu75][Hu488][Hu166][Hu65][Hu171][Hu60][Hu7][Hu54][Hu85][Hu361]\n", + "\n", + "SpiritLM-BASE: Decode back to audio from units (not deduplicated) \n", + " --------------------\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "SpiritLM-BASE: Decode back to audio from string (deduplicated and sorted units) \n", + " --------------------\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "## encode_units\n", + "print('SpiritLM-BASE: Encode audio into units (not deduplicated) \\n', '-'*20)\n", + "units = tokenizer_base.encode_units(audio)\n", + "print(units)\n", + "\n", + "## encode_string\n", + "print('\\nSpiritLM-BASE: Encode audio into string (deduplicated and sorted units) \\n', '-'*20)\n", + "string_tokens = tokenizer_base.encode_string(audio)\n", + "print(string_tokens)\n", + "\n", + "## decode from units\n", + "print('\\nSpiritLM-BASE: Decode back to audio from units (not deduplicated) \\n', '-'*20)\n", + "resyn_wav = tokenizer_base.decode(units, speaker_id=2)\n", + "ipd.display(ipd.Audio(resyn_wav, rate=16000))\n", + "\n", + "## decode from string\n", + "print('\\nSpiritLM-BASE: Decode back to audio from string (deduplicated and sorted units) \\n', '-'*20)\n", + "resyn_dedup_wav = tokenizer_base.decode(string_tokens, speaker_id=2)\n", + "ipd.display(ipd.Audio(resyn_dedup_wav, rate=16000))" + ] + }, + { + "cell_type": "markdown", + "id": "9c0af45c-b397-4f41-b2e1-77f9199f5c3d", + "metadata": {}, + "source": [ + "### SpiritLM-EXPRESSIVE Tokenizer" + ] + 
}, + { + "cell_type": "code", + "execution_count": 5, + "id": "e03a3594-f387-4b12-aaab-dd8e56a583de", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " [INFO]: device is not None, use cuda:0\n", + " [INFO] > call by:torchfcpe.tools.spawn_infer_cf_naive_mel_pe_from_pt\n", + " [WARN] args.model.use_harmonic_emb is None; use default False\n", + " [WARN] > call by:torchfcpe.tools.spawn_cf_naive_mel_pe\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/.conda/envs/spiritlm/lib/python3.10/site-packages/torchfcpe/models_infer.py:191: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n", + " ckpt = torch.load(pt_path, map_location=torch.device(device))\n", + "Some weights of Wav2Vec2StyleEncoder were not initialized from the model checkpoint at checkpoints/speech_tokenizer/style_encoder_w2v2 and are newly initialized: ['_float_tensor']\n", + "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n", + "/home/.conda/envs/spiritlm/lib/python3.10/site-packages/torch/nn/utils/weight_norm.py:134: FutureWarning: `torch.nn.utils.weight_norm` is deprecated in favor of `torch.nn.utils.parametrizations.weight_norm`.\n", + " WeightNorm.apply(module, name, dim)\n" + ] + } + ], + "source": [ + "## Load the tokenizer\n", + "tokenizer_expressive = spiritlm_expressive()" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "1cacdbc6-8a08-470c-ae67-e7d772535f09", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "SpiritLM-EXPRESSIVE: Encode audio into units (not deduplicated) \n", + " --------------------\n", + "{'audio': 'examples/audio/7143-88743-0029.flac', 'hubert': '99 49 38 149 149 71 423 427 492 288 315 153 153 389 497 412 247 354 7 96 452 452 176 266 266 77 248 336 336 211 166 65 94 224 224 148 492 191 440 440 41 41 457 79 382 451 332 216 114 340 478 74 79 370 272 370 370 53 477 65 171 60 258 111 111 111 111 338 338 23 23 338 23 338 338 338 7 338 338 149 406 7 361 361 361 99 99 99 99 99 99 99 209 209 209 209 209 479 50 50 7 149 149 35 35 130 130 169 169 72 434 119 272 4 249 245 245 433 159 294 139 359 343 269 302 226 370 216 459 424 424 226 382 7 58 138 428 397 350 350 306 306 306 84 11 171 171 60 314 227 227 355 9 58 138 226 370 272 382 334 330 176 176 307 145 248 493 64 44 388 7 111 111 111 111 23 23 481 149 149 80 70 431 457 79 79 249 249 245 245 245 433 433 316 316 180 458 458 458 86 86 225 103 60 96 119 119 129 356 218 4 259 259 392 490 75 488 166 65 171 60 7 54 54 85 85 361 361', 'pitch': '39 39 39 48 56 40 42 39 51 40 43 54 3 35 39 25 58 26 44 40 13 20 46 41 26 40 26 56 41 
46 46 41 41 40 40 40 39 39 57 59 59 59 59 59 59 59 59 20 20 20 35 35 13 3 9 6 0 20 57 56 56 56 56 59 44 57 41 59 42 51 59 57 59 59 39 39 46 56 58 41 41 40 39 39 39 59 59 59 15 27 13 55 13 27 35 36 3 53 3 26 43 53 54 39 25 14 41 46 46 46 46 41 41 41', 'style': '71 71 71 71 71 71 71 71 71 83'}\n", + "\n", + "SpiritLM-EXPRESSIVE: Encode audio into string (deduplicated and sorted units) \n", + " --------------------\n", + "[St71][Pi39][Hu99][Hu49][Hu38][Hu149][Hu71][Pi48][Hu423][Hu427][Pi56][Hu492][Hu288][Pi40][Hu315][Hu153][Pi42][Hu389][Pi39][Hu497][Hu412][Pi51][Hu247][Hu354][Pi40][Hu7][Hu96][Pi43][Hu452][Pi54][Hu176][Hu266][Pi3][St71][Hu77][Pi35][Hu248][Hu336][Pi39][Hu211][Pi25][Hu166][Hu65][Pi58][Hu94][Hu224][Pi26][Hu148][Pi44][Hu492][Hu191][Pi40][Hu440][Pi13][Hu41][Pi20][Hu457][Hu79][Pi46][Hu382][Hu451][Pi41][Hu332][Hu216][Pi26][Hu114][Hu340][St71][Pi40][Hu478][Hu74][Pi26][Hu79][Hu370][Pi56][Hu272][Hu370][Pi41][Hu53][Pi46][Hu477][Hu65][Hu171][Hu60][Pi41][Hu258][Hu111][Pi40][Hu338][Hu23][Hu338][Pi39][Hu23][Hu338][St71][Pi57][Hu7][Hu338][Pi59][Hu149][Hu406][Hu7][Hu361][Hu99][Hu209][Pi20][Hu479][Hu50][St71][Pi35][Hu7][Hu149][Hu35][Pi13][Hu130][Pi3][Hu169][Pi9][Hu72][Pi6][Hu434][Hu119][Pi0][Hu272][Hu4][Pi20][Hu249][Hu245][Pi57][Hu433][Pi56][Hu159][Hu294][Hu139][Hu359][Hu343][Hu269][Hu302][St71][Hu226][Pi59][Hu370][Hu216][Pi44][Hu459][Hu424][Pi57][Hu226][Pi41][Hu382][Hu7][Pi59][Hu58][Hu138][Pi42][Hu428][Hu397][Pi51][Hu350][Pi59][Hu306][Pi57][Hu84][Pi59][Hu11][Hu171][Hu60][Pi39][Hu314][Hu227][St71][Hu355][Pi46][Hu9][Hu58][Pi56][Hu138][Hu226][Pi58][Hu370][Hu272][Pi41][Hu382][Hu334][Hu330][Hu176][Pi40][Hu307][Pi39][Hu145][Hu248][Hu493][Hu64][Hu44][Hu388][Pi59][Hu7][Hu111][St71][Hu23][Pi15][Hu481][Pi27][Hu149][Pi13][Hu80][Hu70][Pi55][Hu431][Hu457][Pi13][Hu79][Pi27][Hu249][Pi35][Hu245][Pi36][Hu433][Pi3][Hu316][Pi53][Hu180][Pi3][Hu458][Pi26][Hu86][St71][Pi43][Hu225][Pi53][Hu103][Hu60][Pi54][Hu96][Hu119][Pi39][Hu129][Pi25][Hu356][Hu218][Pi14][Hu4][Hu259][Pi41][Hu392][Pi46][Hu490][Hu75][Hu488][Hu166][Hu65][Hu171][Hu60][Hu7][Pi41][Hu54][Hu85][St83][Hu361]\n", + "\n", + "SpiritLM-EXPRESSIVE: Decode back to audio from units (not deduplicated) \n", + " --------------------\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "SpiritLM-EXPRESSIVE: Decode back to audio from string (deduplicated and sorted units) \n", + " --------------------\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "## encode_units\n", + "print('SpiritLM-EXPRESSIVE: Encode audio into units (not deduplicated) \\n', '-'*20)\n", + "units = tokenizer_expressive.encode_units(audio)\n", + "print(units)\n", + "\n", + "## encode_string\n", + "print('\\nSpiritLM-EXPRESSIVE: Encode audio into string (deduplicated and sorted units) \\n', '-'*20)\n", + "string_tokens = tokenizer_expressive.encode_string(audio)\n", + "print(string_tokens)\n", + "\n", + "## decode from units\n", + "print('\\nSpiritLM-EXPRESSIVE: Decode back to audio from units (not deduplicated) \\n', '-'*20)\n", + "resyn_wav = tokenizer_expressive.decode(units, speaker_id=2)\n", + "ipd.display(ipd.Audio(resyn_wav, rate=16000))\n", + "\n", + "## decode from string\n", + "print('\\nSpiritLM-EXPRESSIVE: Decode back to audio from string (deduplicated and sorted units) \\n', 
'-'*20)\n", + "resyn_dedup_wav = tokenizer_expressive.decode(string_tokens, speaker_id=2)\n", + "ipd.display(ipd.Audio(resyn_dedup_wav, rate=16000))" + ] + }, + { + "cell_type": "markdown", + "id": "e46f7f7c-d26b-4ead-8d1b-fb013c1dd9d1", + "metadata": {}, + "source": [ + "## Test load each component" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "d48c0595-0d25-4eab-8958-d2f30f940d88", + "metadata": {}, + "outputs": [], + "source": [ + "from spiritlm.speech_tokenizer.hubert import spiritlm_hubert\n", + "from spiritlm.speech_tokenizer.hifigan import spiritlm_base_hifigan, spiritlm_expressive_hifigan_w2v2\n", + "from spiritlm.speech_tokenizer.f0 import spiritlm_expressive_f0\n", + "from spiritlm.speech_tokenizer.style_encoder import spiritlm_expressive_style_encoder_w2v2" + ] + }, + { + "cell_type": "markdown", + "id": "a686cd17-c669-4860-8835-7e5406f0b0d9", + "metadata": {}, + "source": [ + "### Hubert Tokenizer" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "f5a5ae21-028e-4c9f-a089-47f16b1ed1a3", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "hubert_tokenizer(audio):\n", + " tensor([ 99, 49, 38, 149, 149, 71, 423, 427, 492, 288, 315, 153, 153, 389,\n", + " 497, 412, 247, 354, 7, 96, 452, 452, 176, 266, 266, 77, 248, 336,\n", + " 336, 211, 166, 65, 94, 224, 224, 148, 492, 191, 440, 440, 41, 41,\n", + " 457, 79, 382, 451, 332, 216, 114, 340, 478, 74, 79, 370, 272, 370,\n", + " 370, 53, 477, 65, 171, 60, 258, 111, 111, 111, 111, 338, 338, 23,\n", + " 23, 338, 23, 338, 338, 338, 7, 338, 338, 149, 406, 7, 361, 361,\n", + " 361, 99, 99, 99, 99, 99, 99, 99, 209, 209, 209, 209, 209, 479,\n", + " 50, 50, 7, 149, 149, 35, 35, 130, 130, 169, 169, 72, 434, 119,\n", + " 272, 4, 249, 245, 245, 433, 159, 294, 139, 359, 343, 269, 302, 226,\n", + " 370, 216, 459, 424, 424, 226, 382, 7, 58, 138, 428, 397, 350, 350,\n", + " 306, 306, 306, 84, 11, 171, 171, 60, 314, 227, 227, 355, 9, 58,\n", + " 138, 226, 370, 272, 382, 334, 330, 176, 176, 307, 145, 248, 493, 64,\n", + " 44, 388, 7, 111, 111, 111, 111, 23, 23, 481, 149, 149, 80, 70,\n", + " 431, 457, 79, 79, 249, 249, 245, 245, 245, 433, 433, 316, 316, 180,\n", + " 458, 458, 458, 86, 86, 225, 103, 60, 96, 119, 119, 129, 356, 218,\n", + " 4, 259, 259, 392, 490, 75, 488, 166, 65, 171, 60, 7, 54, 54,\n", + " 85, 85, 361, 361], device='cuda:0')\n" + ] + } + ], + "source": [ + "hubert_tokenizer = spiritlm_hubert()\n", + "print(\"hubert_tokenizer(audio):\\n\", hubert_tokenizer(audio))" + ] + }, + { + "cell_type": "markdown", + "id": "86a6d8f1-d6c8-43c9-9a1c-eb9bbcad3791", + "metadata": {}, + "source": [ + "### Pitch Tokenizer" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "21b03535-37a4-4ce5-8045-b309c2ebe51a", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " [INFO]: device is not None, use cuda:0\n", + " [INFO] > call by:torchfcpe.tools.spawn_infer_cf_naive_mel_pe_from_pt\n", + " [WARN] args.model.use_harmonic_emb is None; use default False\n", + " [WARN] > call by:torchfcpe.tools.spawn_cf_naive_mel_pe\n", + "f0_tokenizer(audio):\n", + " tensor([39, 39, 39, 48, 56, 40, 42, 39, 51, 40, 43, 54, 3, 35, 39, 25, 58, 26,\n", + " 44, 40, 13, 20, 46, 41, 26, 40, 26, 56, 41, 46, 46, 41, 41, 40, 40, 40,\n", + " 39, 39, 57, 59, 59, 59, 59, 59, 59, 59, 59, 20, 20, 20, 35, 35, 13, 3,\n", + " 9, 6, 0, 20, 57, 56, 56, 56, 56, 59, 44, 57, 41, 59, 42, 51, 59, 57,\n", + " 59, 59, 39, 39, 46, 56, 58, 41, 41, 40, 39, 
39, 39, 59, 59, 59, 15, 27,\n", + " 13, 55, 13, 27, 35, 36, 3, 53, 3, 26, 43, 53, 54, 39, 25, 14, 41, 46,\n", + " 46, 46, 46, 41, 41, 41], device='cuda:0')\n" + ] + } + ], + "source": [ + "f0_tokenizer = spiritlm_expressive_f0()\n", + "print(\"f0_tokenizer(audio):\\n\", f0_tokenizer(audio))" + ] + }, + { + "cell_type": "markdown", + "id": "33e2ef11-82de-42cb-be3d-16d32b5b8510", + "metadata": {}, + "source": [ + "### Style Tokenizer" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "0eb0fef0-5140-4829-84b3-624e7baeecdd", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Some weights of Wav2Vec2StyleEncoder were not initialized from the model checkpoint at checkpoints/speech_tokenizer/style_encoder_w2v2 and are newly initialized: ['_float_tensor']\n", + "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "style_tokenizer(audio):\n", + " tensor([71, 71, 71, 71, 71, 71, 71, 71, 71, 83], device='cuda:0')\n" + ] + } + ], + "source": [ + "style_tokenizer = spiritlm_expressive_style_encoder_w2v2()\n", + "print(\"style_tokenizer(audio):\\n\", style_tokenizer(audio))" + ] + }, + { + "cell_type": "markdown", + "id": "8f00182f-3e7c-4ea1-b0db-3e1bf386dae5", + "metadata": {}, + "source": [ + "### Hifi-GAN Vocoders" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "422f5566-2a08-46b7-84ba-ae43c5d9cd0f", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Base vocoder\n", + "base_vocoder = spiritlm_base_hifigan()\n", + "wav = base_vocoder(hubert_tokenizer(audio), speaker_id=1).cpu().numpy()\n", + "ipd.Audio(wav, rate=16000)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "f82246da-0396-45c0-ab44-f83f8231ca75", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Expressive vocoder\n", + "expressive_vocoder = spiritlm_expressive_hifigan_w2v2()\n", + "wav_ex = expressive_vocoder(\n", + " code=hubert_tokenizer(audio),\n", + " f0_code=f0_tokenizer(audio),\n", + " style_code=style_tokenizer(audio),\n", + " dur_pred=False,\n", + " speaker_id=1,\n", + " not_dedup_code=True,\n", + " ).cpu().numpy()\n", + "ipd.Audio(wav_ex, rate=16000)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "979c6080-25e1-4516-9482-67b364687a88", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.4" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/requirements.dev.txt b/requirements.dev.txt new file mode 100644 index 0000000000000000000000000000000000000000..55b033e901cdda93a26ac64b418f260224260a39 --- /dev/null +++ b/requirements.dev.txt @@ -0,0 +1 @@ +pytest \ No newline at end of file diff --git a/requirements.txt 
b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..ba4c2c856d97f8d289405f2889f59f89bb9d0b83 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,9 @@ +omegaconf>=2.2.0 +librosa>=0.10 +local-attention>=1.9 +encodec>=0.1 +transformers +fairscale>=0.4 +sentencepiece +pyarrow>=14.0 +torchfcpe>=0.0.4 \ No newline at end of file diff --git a/setup.py b/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..80a067e8d828b9590c700b0b717c437d1f0ae9db --- /dev/null +++ b/setup.py @@ -0,0 +1,60 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the FAIR Noncommercial Research License +# found in the LICENSE file in the root directory of this source tree. + +import os +from pathlib import Path + +from setuptools import find_packages, setup + +NAME = "spiritlm" +VERSION = "0.1.0" +DESCRIPTION = "Interleaved Spoken and Written Language Model" +URL = "https://github.com/facebookresearch/spiritlm" +KEYWORDS = [ + "Language Model, Speech Language Model, Multimodal, Crossmodal, Expressivity Modeling" +] +LICENSE = "FAIR Noncommercial Research License" + + +def _get_long_description(): + with (Path(__file__).parent / "README.md").open(encoding="utf-8") as file: + long_description = file.read() + return long_description + + +def _read_reqs(relpath): + fullpath = os.path.join(os.path.dirname(__file__), relpath) + with open(fullpath) as f: + return [ + s.strip() for s in f.readlines() if (s.strip() and not s.startswith("#")) + ] + + +setup( + name=NAME, + version=VERSION, + description=DESCRIPTION, + long_description=_get_long_description(), + long_description_content_type="text/plain", + url=URL, + license=LICENSE, + author="Meta", + keywords=KEYWORDS, + classifiers=[ + "Intended Audience :: Science/Research", + "License :: FAIR Noncommercial Research License", + "Topic :: Multimedia :: Sound/Audio", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + ], + packages=find_packages(), + zip_safe=False, + python_requires=">=3.9", + install_requires=_read_reqs("requirements.txt"), + extras_require={ + "dev": ["pytest"], + "eval": ["pandas"], + }, +) diff --git a/spiritlm.egg-info/PKG-INFO b/spiritlm.egg-info/PKG-INFO new file mode 100644 index 0000000000000000000000000000000000000000..57d9dc7f754443a11f4fdc0f617fa13f4343eae5 --- /dev/null +++ b/spiritlm.egg-info/PKG-INFO @@ -0,0 +1,84 @@ +Metadata-Version: 2.1 +Name: spiritlm +Version: 0.1.0 +Summary: Interleaved Spoken and Written Language Model +Home-page: https://github.com/facebookresearch/spiritlm +Author: Meta +License: FAIR Noncommercial Research License +Keywords: Language Model, Speech Language Model, Multimodal, Crossmodal, Expressivity Modeling +Classifier: Intended Audience :: Science/Research +Classifier: License :: FAIR Noncommercial Research License +Classifier: Topic :: Multimedia :: Sound/Audio +Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence +Requires-Python: >=3.9 +Description-Content-Type: text/plain +License-File: LICENSE +Requires-Dist: omegaconf>=2.2.0 +Requires-Dist: librosa>=0.10 +Requires-Dist: local-attention>=1.9 +Requires-Dist: encodec>=0.1 +Requires-Dist: transformers +Requires-Dist: fairscale>=0.4 +Requires-Dist: sentencepiece +Requires-Dist: pyarrow>=14.0 +Requires-Dist: torchfcpe>=0.0.4 +Provides-Extra: dev +Requires-Dist: pytest; extra == "dev" +Provides-Extra: eval +Requires-Dist: pandas; extra == "eval" + +# Meta Spirit LM: Interleaved Spoken and Written Language Model + 
+This repository contains the model weights, inference code and evaluation scripts for the Spirit LM [paper](https://arxiv.org/pdf/2402.05755.pdf). You can find more generation samples on our [demo page](https://speechbot.github.io/spiritlm/). + +## Spirit LM Model Overview + + +## Installation Setup +### Conda +``` +conda env create -f env.yml +pip install -e '.[eval]' + +``` +### Pip +``` +pip install -e '.[eval]' +``` + +### Dev +(Optionally, use only if you want to run the tests.) +``` +pip install -e '.[dev]' +``` + +## Checkpoints Setup +See [checkpoints/README.md](checkpoints/README.md) + +## Quick Start +### Speech Tokenization +See [spiritlm/speech_tokenizer/README.md](spiritlm/speech_tokenizer/README.md) +### Spirit LM Generation +See [spiritlm/model/README.md](spiritlm/model/README.md) +### Speech-Text Sentiment Preservation benchmark (STSP) +See [spiritlm/eval/README.md](spiritlm/eval/README.md) + +## Model Card +More details of the model can be found in [MODEL_CARD.md](MODEL_CARD.md). + +## License +The present code is provided under the **FAIR Noncommercial Research License** found in [LICENSE](LICENSE). + +## Citation +``` +@misc{nguyen2024spiritlminterleavedspokenwritten, + title={SpiRit-LM: Interleaved Spoken and Written Language Model}, + author={Tu Anh Nguyen and Benjamin Muller and Bokai Yu and Marta R. Costa-jussa and Maha Elbayad and Sravya Popuri and Paul-Ambroise Duquenne and Robin Algayres and Ruslan Mavlyutov and Itai Gat and Gabriel Synnaeve and Juan Pino and Benoit Sagot and Emmanuel Dupoux}, + year={2024}, + eprint={2402.05755}, + archivePrefix={arXiv}, + primaryClass={cs.CL}, + url={https://arxiv.org/abs/2402.05755}, +} +``` + diff --git a/spiritlm.egg-info/SOURCES.txt b/spiritlm.egg-info/SOURCES.txt new file mode 100644 index 0000000000000000000000000000000000000000..33b103770080c88919dab40fdfb2bb7289746c9b --- /dev/null +++ b/spiritlm.egg-info/SOURCES.txt @@ -0,0 +1,32 @@ +LICENSE +README.md +setup.py +spiritlm/__init__.py +spiritlm.egg-info/PKG-INFO +spiritlm.egg-info/SOURCES.txt +spiritlm.egg-info/dependency_links.txt +spiritlm.egg-info/not-zip-safe +spiritlm.egg-info/requires.txt +spiritlm.egg-info/top_level.txt +spiritlm/model/__init__.py +spiritlm/model/spiritlm_model.py +spiritlm/model/utils.py +spiritlm/speech_tokenizer/__init__.py +spiritlm/speech_tokenizer/spiritlm_tokenizer.py +spiritlm/speech_tokenizer/f0/__init__.py +spiritlm/speech_tokenizer/f0/f0_extractor.py +spiritlm/speech_tokenizer/f0/f0_tokenizer.py +spiritlm/speech_tokenizer/f0/vqvae.py +spiritlm/speech_tokenizer/hifigan/__init__.py +spiritlm/speech_tokenizer/hifigan/hifigan_vocoder.py +spiritlm/speech_tokenizer/hubert/__init__.py +spiritlm/speech_tokenizer/hubert/hubert_tokenizer.py +spiritlm/speech_tokenizer/hubert/quantizer_model.py +spiritlm/speech_tokenizer/hubert/hubert_model/__init__.py +spiritlm/speech_tokenizer/hubert/hubert_model/hubert_model.py +spiritlm/speech_tokenizer/hubert/hubert_model/wav2vec2_model.py +spiritlm/speech_tokenizer/style_encoder/__init__.py +spiritlm/speech_tokenizer/style_encoder/w2v2_encoder.py +tests/__init__.py +tests/test_spirit_model.py +tests/test_tokenizer.py \ No newline at end of file diff --git a/spiritlm.egg-info/dependency_links.txt b/spiritlm.egg-info/dependency_links.txt new file mode 100644 index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc --- /dev/null +++ b/spiritlm.egg-info/dependency_links.txt @@ -0,0 +1 @@ + diff --git a/spiritlm.egg-info/not-zip-safe b/spiritlm.egg-info/not-zip-safe new file mode 
100644 index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc --- /dev/null +++ b/spiritlm.egg-info/not-zip-safe @@ -0,0 +1 @@ + diff --git a/spiritlm.egg-info/requires.txt b/spiritlm.egg-info/requires.txt new file mode 100644 index 0000000000000000000000000000000000000000..e28f802f9c29a24e0588066f0ab98f8dcf307300 --- /dev/null +++ b/spiritlm.egg-info/requires.txt @@ -0,0 +1,15 @@ +omegaconf>=2.2.0 +librosa>=0.10 +local-attention>=1.9 +encodec>=0.1 +transformers +fairscale>=0.4 +sentencepiece +pyarrow>=14.0 +torchfcpe>=0.0.4 + +[dev] +pytest + +[eval] +pandas diff --git a/spiritlm.egg-info/top_level.txt b/spiritlm.egg-info/top_level.txt new file mode 100644 index 0000000000000000000000000000000000000000..e3a2641a4a49bd3769a9ec5017b539a2c49e0eaa --- /dev/null +++ b/spiritlm.egg-info/top_level.txt @@ -0,0 +1,2 @@ +spiritlm +tests diff --git a/spiritlm/__init__.py b/spiritlm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..17edf7fe569a6e89feb6107141c62c96ab684ea5 --- /dev/null +++ b/spiritlm/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the FAIR Noncommercial Research License +# found in the LICENSE file in the root directory of this source tree. diff --git a/spiritlm/__pycache__/__init__.cpython-310.pyc b/spiritlm/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9baadc2fd46258dffe01f383c5853a5104d055ce Binary files /dev/null and b/spiritlm/__pycache__/__init__.cpython-310.pyc differ diff --git a/spiritlm/eval/README.md b/spiritlm/eval/README.md new file mode 100644 index 0000000000000000000000000000000000000000..a60d669cf0864417c0f21d41765eb22fc13ce32e --- /dev/null +++ b/spiritlm/eval/README.md @@ -0,0 +1,92 @@ +# STSP Evaluation +The Speech-Text Sentiment Preservation (STSP) benchmark is a collection of speech and text prompts with positive, negative, or neutral sentiment. +Given a spoken or written prompt, the task consists of generating a text or speech sequence of tokens that preserves the sentiment of the prompt. + +The sentiment of the generated sequence is evaluated automatically with a sentiment/emotion classifier in speech or text, depending on the output modality, and compared with the sentiment of the prompt. +From this comparison, we derive an STSP accuracy score.
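+
+For orientation, the prediction files consumed by the evaluation scripts below are JSONL files in which each line carries the record `id` from the reference manifest and the predicted sentiment label, while the reference side carries a `sentiment` (or `emotion`) field. A minimal illustration (the ids here are made up):
+```
+{"id": "1", "pred": "positive"}
+{"id": "2", "pred": "neutral"}
+```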
+ +## Data Download +Download the data as well as the speech/text classifier checkpoints via this [link](https://dl.fbaipublicfiles.com/textless_nlp/spiritlm/stsp.tar.gz) +then extract the data into the folder `{spiritlm ROOT FOLDER}/data/stsp_data` +``` +cd {spiritlm ROOT FOLDER} +mkdir data/stsp_data +tar -xvzf stsp.tar.gz -C data/stsp_data --strip-components=1 +``` +Run the following script to check the dataset is all correctly present: +``` +python spiritlm/eval/stsp/sanity_check_download.py +``` +## Data structure +The dataset contains 3 folders: +- `data`: raw audio files +- `manifest`: data splits +- `model`: speech/text classifier checkpoints +### Data +The raw audio files for +- `emov`: EMOV +- `expresso/conversational`: EXPRESSO-ASR +- `expresso/read`: EXPRESSO-READ + +### Manifest +The train/validation/test splits, concretely we have: + +#### EMOV +- 1053 records for emov train split at `manifest/emov/emov.train.jsonl` +- 351 records for emov dev split at `manifest/emov/emov.dev.jsonl` +- 351 records for emov test split at `manifest/emov/emov.test.jsonl` + +#### EXPRESSO-ASR +- 1373 records for EXPRESSO-ASR train split at `manifest/expresso/expresso_asr.train` +- 479 records for EXPRESSO-ASR dev at `manifest/expresso/expresso_asr.dev.jsonl` +- 462 records for EXPRESSO-ASR test split at `manifest/expresso/expresso_asr.test.jsonl` + +#### EXPRESSO-READ +- 1024 records for EXPRESSO-READ train split at `manifest/expresso/expresso_read.train` +- 60 records for EXPRESSO-READ dev at `manifest/expresso/expresso_read.dev.jsonl` +- 54 records for EXPRESSO-READ test split at `manifest/expresso/expresso_read.test.jsonl` + +#### Few-shot Samples +The subset from EXPRESSO-ASR training set, used for the few-shot experiments: +- `s2s.jsonl`: S -> S direction +- `s2t.jsonl`: S -> T direction +- `t2t.jsonl`: T -> T direction +- `t2s.jsonl`: T -> S direction + +### Auto-Eval Speech And Text Classifiers + +The sentiment of the generated sequence is estimated in an auto-eval fashion with Speech and Text classifiers. We point to the [paper](https://arxiv.org/abs/2402.05755) for details on these classifiers. + + +## Prediction & Evaluation of Spirit LM on STSP (Speech/Text) + +```export PYTHONPATH=.``` + +Set `spiritlm` to the model you want to evaluate: e.g. ```spiritlm=spirit-lm-base-7b``` or ```spiritlm=spirit-lm-expressive-7b``` + +#### Speech to Text + torchrun --nnodes 1 --nproc-per-node 1 spiritlm/eval/stsp/predict_stsp.py --model $spiritlm --eval_manifest_path data/stsp_data/manifest/emov/emov.test.jsonl --eval --write_pred ./pred_s_t.jsonl --input_output speech_text +#### Text to Text + torchrun --nnodes 1 --nproc-per-node 1 spiritlm/eval/stsp/predict_stsp.py --model $spiritlm --eval_manifest_path data/stsp_data/manifest/emov/emov.test.jsonl --eval --write_pred ./pred_t_t.jsonl --input_output text_text +#### Text to Speech + torchrun --nnodes 1 --nproc-per-node 1 spiritlm/eval/stsp/predict_stsp.py --model $spiritlm --eval_manifest_path data/stsp_data/manifest/emov/emov.test.jsonl --eval --write_pred ./pred_t_s.jsonl --input_output text_speech +#### Speech to Speech + torchrun --nnodes 1 --nproc-per-node 1 spiritlm/eval/stsp/predict_stsp.py --model $spiritlm --eval_manifest_path data/stsp_data/manifest/emov/emov.test.jsonl --eval --write_pred ./pred_s_s.jsonl --input_output speech_speech + + +### Post-hoc Evaluation + +To evaluate the performance of a model different from SpiritLM, you can use the following evaluation script that takes as input a prediction.jsonl file. 
+ +``` +python spiritlm/eval/eval_stsp.py --ref_file $REF_FILE --pred_file $pred_file +``` + +e.g. + +``` +python spiritlm/eval/eval_stsp.py \ +--ref_file ./data/examples/demo.jsonl \ +--pred_file ./data/examples/pred.jsonl +> Accuracy: 100.00% for predictions ./data/examples/pred.jsonl +``` diff --git a/spiritlm/eval/eval_stsp.py b/spiritlm/eval/eval_stsp.py new file mode 100644 index 0000000000000000000000000000000000000000..d37b12be8b2fc79a9e6f9ac4f3eb466f145648a7 --- /dev/null +++ b/spiritlm/eval/eval_stsp.py @@ -0,0 +1,87 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the FAIR Noncommercial Research License +# found in the LICENSE file in the root directory of this source tree. + +import argparse +import json +from typing import Dict, Union + +import pandas as pd +from spiritlm.eval.stsp.utils import EMOTION_2_SENTIMENT + + +def load_pred(predictions): + ret = {} + with open(predictions) as f: + for line in f: + pred = json.loads(line) + ret[str(pred["id"])] = pred["pred"] + + assert sum(1 for _ in open(predictions)) == len(ret) + + return ret + + +def eval( + gold_records: str, predictions: Union[str, Dict], info_data="", label="sentiment" +): + n_gold_records = sum(1 for _ in open(gold_records)) + n_lines_pred = ( + sum(1 for _ in open(predictions)) + if isinstance(predictions, str) + else len(predictions) + ) + assert ( + n_gold_records == n_lines_pred + ), f"Mismatch between prediction ({n_lines_pred} samples in {predictions}) and reference ({n_gold_records} in {gold_records})" + + pred_dic = load_pred(predictions) if isinstance(predictions, str) else predictions + scores = [] + + with open(gold_records) as gold: + for line in gold: + ref = json.loads(line) + try: + if label in ref: + scores.append(pred_dic[str(ref["id"])] == ref[label]) + else: + assert label == "sentiment" and "emotion" in ref, ref + sentiment = EMOTION_2_SENTIMENT[ref["emotion"]] + scores.append(pred_dic[str(ref["id"])] == sentiment) + except Exception as e: + print( + f"ERROR in matching the predicted labels with the gold ones: {e}: ref['id'] do not match any key in {pred_dic}', {ref['id']}: " + ) + # TODO: add other metrics if needed : F1 per class, etc. + report = pd.DataFrame({"Correct": scores}) + if isinstance(predictions, str): + info_data += f"from {predictions}" + print( + f"Accuracy: {(report['Correct']==1).sum()/len(report)*100:0.2f}% for predictions {info_data}" + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--ref_file", + type=str, + help="Path to reference record", + ) + parser.add_argument( + "--pred_file", + type=str, + help="Path to prediction: should be jsonl with each entry {'pred': , 'id': }", + ) + parser.add_argument( + "--label", + type=str, + default="sentiment", + help="sentiment or emotion", + ) + args = parser.parse_args() + + eval(args.ref_file, args.pred_file, label=args.label) diff --git a/spiritlm/eval/load_data.py b/spiritlm/eval/load_data.py new file mode 100644 index 0000000000000000000000000000000000000000..093a8d5154b1293e607c50a7e69f5a3a7eb0e81b --- /dev/null +++ b/spiritlm/eval/load_data.py @@ -0,0 +1,50 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the FAIR Noncommercial Research License +# found in the LICENSE file in the root directory of this source tree. 
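+
+# Minimal torch Datasets used by the STSP evaluation scripts: SpeechData returns the
+# waveform referenced by the "wav_path" field of a JSONL manifest, TextData returns
+# the "asr" text field, and both expose the record "id" so that predictions can be
+# matched back to the reference manifest.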
+ +import json +from pathlib import Path + +import torch +import torchaudio + + +class SpeechData(torch.utils.data.Dataset): + def __init__(self, manifest_dir, root_dir=None): + if root_dir is None: + root_dir = "." + self.root_dir = Path(root_dir) + self.manifest_dir = self.root_dir / manifest_dir + self.wav_field = "wav_path" + self.manifest = [json.loads(line.strip()) for line in open(manifest_dir)] + + def __getitem__(self, idx): + wav_path = self.root_dir / self.manifest[idx][self.wav_field] + return { + "wav": torchaudio.load(wav_path)[0].squeeze(0), + "id": str(self.manifest[idx]["id"]), + } + + def __len__(self): + return len(self.manifest) + + +class TextData(torch.utils.data.Dataset): + def __init__(self, manifest_dir, root_dir=None): + if root_dir is None: + root_dir = "." + self.root_dir = Path(root_dir) + self.manifest_dir = self.root_dir / manifest_dir + self.text_field = "asr" + self.manifest = [json.loads(line.strip()) for line in open(manifest_dir)] + + def __getitem__(self, idx): + return { + "text": self.manifest[idx][self.text_field], + "id": str(self.manifest[idx]["id"]), + } + + def __len__(self): + return len(self.manifest) diff --git a/spiritlm/eval/stsp/few_shot_prompt.py b/spiritlm/eval/stsp/few_shot_prompt.py new file mode 100644 index 0000000000000000000000000000000000000000..64e89cdfac0f4a78a67710b7cc2c6f5b60f519cc --- /dev/null +++ b/spiritlm/eval/stsp/few_shot_prompt.py @@ -0,0 +1,101 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the FAIR Noncommercial Research License +# found in the LICENSE file in the root directory of this source tree. + +import math +from typing import Union + +import pandas as pd +import torch +import torchaudio +from spiritlm.eval.stsp.stsp_constants import STSP_DATA_ROOT, STSP_MANIFEST_ROOT +from spiritlm.model.spiritlm_model import Spiritlm + +FEW_SHOT_MANIFEST_DIR = STSP_MANIFEST_ROOT / "few_shot" +FEW_SHOT_TEMPLATE = "{prompt}{generation}" + + +def wav_prompt(spiritlm_model: Spiritlm, wav: Union[str, torch.Tensor]) -> str: + return spiritlm_model.SPEECH_PROMPT_PREFIX + spiritlm_model.speech_tokenizer(wav) + + +def text_prompt(spiritlm_model: Spiritlm, text: str) -> str: + return spiritlm_model.TEXT_PROMPT_PREFIX + text + + +def _load_half_wav(wav_path: str, load_first_half: bool) -> torch.Tensor: + wav_path = STSP_DATA_ROOT / wav_path + wav = torchaudio.load(wav_path)[0].squeeze(0) + size = wav.size()[0] + half_size = size // 2 + if load_first_half: + wav = wav[:half_size] + else: + wav = wav[half_size:] + return wav + + +def build_few_shot_prompt( + spiritlm_model: Spiritlm, + input_output: str, + n_shots: int = 3, +) -> str: + """ + Build the few-shot prompt by simply concatenating a set of examples. 
+ + E.g., a 3-shots T->S prompt would like this: + "[Text]text1[Speech]speech_tokens1\n[Text]text2[Speech]speech_tokens2\n[Text]text3[Speech]speech_tokens3\n" + """ + manifset_file_mapping = { + "text_text": "t2t", + "speech_text": "s2t", + "text_speech": "t2s", + "speech_speech": "s2s", + } + manifest_path = ( + FEW_SHOT_MANIFEST_DIR / f"{manifset_file_mapping[input_output]}.jsonl" + ) + df = pd.read_json(manifest_path, lines=True) + assert n_shots <= len(df) + + # ensure a balanced sampels for each sentiment + nb_samples_per_sentiment = math.ceil(n_shots / 3) + df = df.groupby("sentiment").sample(n=nb_samples_per_sentiment) + + prompts = [] + for _, row in df.iterrows(): + prompt = row["prompt"] + generation = row["generation"] + if input_output == "text_text": + prompt = FEW_SHOT_TEMPLATE.format( + prompt=text_prompt(spiritlm_model, prompt), + generation=text_prompt(spiritlm_model, generation), + ) + elif input_output == "text_speech": + prompt = FEW_SHOT_TEMPLATE.format( + prompt=text_prompt(spiritlm_model, prompt), + generation=wav_prompt( + spiritlm_model, _load_half_wav(generation, load_first_half=False) + ), + ) + elif input_output == "speech_text": + prompt = FEW_SHOT_TEMPLATE.format( + prompt=wav_prompt( + spiritlm_model, _load_half_wav(prompt, load_first_half=True) + ), + generation=text_prompt(spiritlm_model, generation), + ) + elif input_output == "speech_speech": + prompt = FEW_SHOT_TEMPLATE.format( + prompt=wav_prompt( + spiritlm_model, _load_half_wav(prompt, load_first_half=True) + ), + generation=wav_prompt( + spiritlm_model, _load_half_wav(generation, load_first_half=False) + ), + ) + prompts.append(prompt) + print(f"prompts: {prompts}") + return "\n".join(prompts) + "\n" diff --git a/spiritlm/eval/stsp/predict_stsp.py b/spiritlm/eval/stsp/predict_stsp.py new file mode 100644 index 0000000000000000000000000000000000000000..f155106e68584f3c302eb57f007fc2215a5fa8b8 --- /dev/null +++ b/spiritlm/eval/stsp/predict_stsp.py @@ -0,0 +1,299 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the FAIR Noncommercial Research License +# found in the LICENSE file in the root directory of this source tree. + +""" +Usage example: + +cd {SPIRITLM ROOT FOLDER} +export PYTHONPATH=. 
+ +# Speech to Text +torchrun --nnodes 1 --nproc-per-node 1 spiritlm/eval/stsp/predict_stsp.py --eval_manifest_path data/examples/ref.jsonl --eval --write_pred ./pred_s_t.jsonl --input_output speech_text +# Text to Text +torchrun --nnodes 1 --nproc-per-node 1 spiritlm/eval/stsp/predict_stsp.py --eval_manifest_path data/examples/ref.jsonl --eval --write_pred ./pred_t_t.jsonl --input_output text_text +# Text to Speech# +torchrun --nnodes 1 --nproc-per-node 1 spiritlm/eval/stsp/predict_stsp.py --eval_manifest_path data/examples/ref.jsonl --eval --write_pred ./pred._t_s.jsonl --input_output text_speech +# Speech to Speech +torchrun --nnodes 1 --nproc-per-node 1 spiritlm/eval/stsp/predict_stsp.py --eval_manifest_path data/examples/ref.jsonl --eval --write_pred ./pred_s_s.jsonl --input_output speech_speech + +""" + +import argparse +import json +import os +import uuid +from pathlib import Path +from typing import Union + +import torch +import torch.distributed as dist +import torchaudio +from spiritlm.eval.eval_stsp import eval +from spiritlm.eval.load_data import SpeechData, TextData +from spiritlm.eval.stsp.few_shot_prompt import build_few_shot_prompt +from spiritlm.eval.stsp.sentiment_classifiers import ( + get_text_sentiment_prediction, + load_sentiment_classifier, +) +from spiritlm.eval.stsp.stsp_constants import STSP_DATA_ROOT, STSP_MODEL_ROOT +from spiritlm.eval.stsp.utils import ( + ExpressoEmotionClassifier, + load_emotion_classifier, + wav2emotion_and_sentiment, +) +from spiritlm.model.spiritlm_model import ( + ContentType, + GenerationInput, + InterleavedOutputs, + OutputModality, + Spiritlm, +) +from torch.utils.data.distributed import DistributedSampler +from tqdm import tqdm +from transformers import AutoModelForSequenceClassification, GenerationConfig, set_seed + +SPEECH_CLASSIFIER = STSP_MODEL_ROOT / "speech_classifier" +TEXT_CLASSIFIER = STSP_MODEL_ROOT / "text_classifier" + +NB_RETRIES = 3 + + +def get_eval_classifier(args): + if args.input_output.endswith("speech"): + return load_emotion_classifier(str(SPEECH_CLASSIFIER)) + elif args.input_output.endswith("text"): + return load_sentiment_classifier(str(TEXT_CLASSIFIER)) + else: + raise (Exception(f"{args.input_output} not supported")) + + +def get_sentiment( + input_output, + generation, + classifer: Union[AutoModelForSequenceClassification, ExpressoEmotionClassifier], +): + if input_output.endswith("speech"): + _, pred_sentiment = wav2emotion_and_sentiment(generation, classifer) + elif input_output.endswith("text"): + _, pred_sentiment = get_text_sentiment_prediction(generation, classifer) + return pred_sentiment + + +def write_jsonl(dir: str, predictions: dict): + Path(dir).parent.mkdir(exist_ok=True, parents=True) + with open(dir, "w") as f: + for id, result_dict in predictions.items(): + record = {"id": id, **result_dict} + json_string = json.dumps(record) + f.write(json_string + "\n") # Add a newline to separate JSON objects + print(f"{dir} written") + + +def write_wav( + wav, + save_dir: Path, + sample_rate: int = 16_000, +) -> str: + """Save wav under `save_dir` with a random name and return the full path.""" + save_dir.mkdir(exist_ok=True, parents=True) + random_path = save_dir / (str(uuid.uuid4()) + ".wav") + torchaudio.save( + random_path, torch.from_numpy(wav).unsqueeze(0), sample_rate=sample_rate + ) + return str(random_path) + + +def run(args): + world_size = int(os.environ["WORLD_SIZE"]) + world_rank = int(os.environ["RANK"]) + print( + f"Running distributed inference with world_size: {world_size}, world_rank: 
{world_rank}" + ) + dist.init_process_group("nccl", rank=world_rank, world_size=world_size) + set_seed(args.seed) + spiritlm_model = Spiritlm(args.model) + evaluation_classifier = get_eval_classifier(args) + input_output = args.input_output + eval_manifest_path = args.eval_manifest_path + write_wav_output = args.write_wav_output + + if args.few_shot > 0: + prompt = build_few_shot_prompt( + spiritlm_model=spiritlm_model, + input_output=args.input_output, + n_shots=args.few_shot, + ) + else: + prompt = None + + # load + if input_output.startswith("speech"): + eval_dataset = SpeechData(eval_manifest_path, root_dir=STSP_DATA_ROOT) + elif input_output.startswith("text"): + eval_dataset = TextData(eval_manifest_path, root_dir=STSP_DATA_ROOT) + + sampler = DistributedSampler(dataset=eval_dataset) + loader = torch.utils.data.DataLoader( + dataset=eval_dataset, + batch_size=1, # large batch size is not supported yet + sampler=sampler, + num_workers=4, + ) + predictions = {} + if input_output.endswith("speech"): + output_modality = OutputModality.SPEECH + max_new_tokens = 300 + else: + output_modality = OutputModality.TEXT + max_new_tokens = 50 + for _, data in tqdm( + enumerate(loader), + desc=f"Predict {eval_manifest_path}", + total=eval_dataset.__len__() // world_size, + ): + # retry the generation multiple times because sometime it does not generate hubert tokens + for i in range(NB_RETRIES): + try: + out: InterleavedOutputs = spiritlm_model.generate( + output_modality=output_modality, + interleaved_inputs=[ + GenerationInput( + content=( + data["wav"][0] + if input_output.startswith("speech") + else data["text"][0] + ), # 0 because of batch size 1 + content_type=( + ContentType.SPEECH + if input_output.startswith("speech") + else ContentType.TEXT + ), + ) + ], + generation_config=GenerationConfig( + temperature=0.8, + top_p=0.95, + max_new_tokens=max_new_tokens, + do_sample=True, + ), + prompt=prompt, + ) + except Exception as e: + print(f"Got an exception when generating: {e}") + if i == NB_RETRIES - 1: + raise Exception(f"Failed to generate after {NB_RETRIES}") + else: + break + assert len(out) == 1 + generated_output = out[0].content + detected_sentiment = get_sentiment( + input_output, generated_output, evaluation_classifier + ) + if output_modality == OutputModality.TEXT: + generation = generated_output + elif write_wav_output and output_modality == OutputModality.SPEECH: + generation = write_wav(generated_output, Path(write_wav_output)) + else: + generation = None + result_dict = {"pred": detected_sentiment} + if generation is not None: + result_dict["generation"] = generation + predictions[str(data["id"][0])] = result_dict + + if args.eval: + gathered_predictions = [None for _ in range(world_size)] + dist.gather_object( + predictions, gathered_predictions if world_rank == 0 else None, dst=0 + ) + if world_rank == 0: + all_predictions = {k: v for d in gathered_predictions for k, v in d.items()} + eval( + eval_manifest_path, + {k: v["pred"] for k, v in all_predictions.items()}, + info_data=f"{eval_manifest_path}, input-output {input_output}", + label="sentiment", + ) + + if args.write_pred is not None and world_rank == 0: + write_jsonl(args.write_pred, all_predictions) + + +def setup_env(): + os.environ["OMP_NUM_THREADS"] = "1" + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--eval_manifest_path", # data/examples/ref.jsonl + type=str, + help="Path to reference record", + required=True, + ) + + parser.add_argument( + "--data_root_dir", # 
data/stsp_data + type=str, + help=f"Path to root data folder, default to {str(STSP_DATA_ROOT)}", + default=str(STSP_DATA_ROOT), + required=False, + ) + + parser.add_argument( + "--model", + type=str, + default="spirit-lm-expressive-7b", + help="Model name (spirit-lm-base-7b or spirit-lm-expressive-7b) or path to model", + required=False, + ) + parser.add_argument( + "--few_shot", + type=int, + default=0, + help="Number of few shot examples, 3/6/9", + required=False, + ) + parser.add_argument( + "--input_output", + type=str, + default="speech_speech", + help="speech_speech speech_text text_speech text_text", + required=False, + ) + parser.add_argument( + "--eval_type", + type=str, + default="emotion", + required=False, + ) + parser.add_argument( + "--write_pred", + type=str, + default=None, + help="Path to save the predictions output", + required=False, + ) + parser.add_argument( + "--write_wav_output", + type=str, + default=None, + help="Path to save the generated audio if the output is speech", + required=False, + ) + parser.add_argument( + "--eval", + default=False, + action="store_true", + ) + parser.add_argument( + "--seed", + default=0, + type=int, + ) + + args = parser.parse_args() + setup_env() + run(args) diff --git a/spiritlm/eval/stsp/sanity_check_download.py b/spiritlm/eval/stsp/sanity_check_download.py new file mode 100644 index 0000000000000000000000000000000000000000..a96bd22baf354d204f0c7509cbf966b0fc08905c --- /dev/null +++ b/spiritlm/eval/stsp/sanity_check_download.py @@ -0,0 +1,30 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the FAIR Noncommercial Research License +# found in the LICENSE file in the root directory of this source tree. + +import json + +from spiritlm.eval.stsp.stsp_constants import STSP_DATA_ROOT, STSP_MANIFEST_ROOT + + +def check_all_datasets(): + for dataset_manifset in STSP_MANIFEST_ROOT.glob("**/*jsonl"): + records_checked = 0 + print(f"dataset_manifset: {dataset_manifset}") + with dataset_manifset.open() as f: + for record in f: + record = json.loads(record) + for wav_key in ["wav_path", "prompt", "generation"]: + if wav_key in record and record[wav_key].endswith(".wav"): + wav_path = STSP_DATA_ROOT / record[wav_key] + assert ( + wav_path.is_file() + ), f"Record {record[wav_key]} not found in {str(wav_path)} and listed in {dataset_manifset}" + records_checked += 1 + print(f"{records_checked} records checked for {dataset_manifset.stem} split") + + +if __name__ == "__main__": + check_all_datasets() diff --git a/spiritlm/eval/stsp/sentiment_classifiers.py b/spiritlm/eval/stsp/sentiment_classifiers.py new file mode 100644 index 0000000000000000000000000000000000000000..e6c2e5c9c4df38edbc3bd4f98cad7f36eaee223a --- /dev/null +++ b/spiritlm/eval/stsp/sentiment_classifiers.py @@ -0,0 +1,37 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the FAIR Noncommercial Research License +# found in the LICENSE file in the root directory of this source tree. 
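+
+# Text-side classifier for STSP auto-eval: load_sentiment_classifier wraps a
+# fine-tuned sequence-classification checkpoint in a Hugging Face text-classification
+# pipeline (tokenizer from j-hartmann/sentiment-roberta-large-english-3-classes,
+# top_k=None so all class scores are returned), and get_text_sentiment_prediction
+# keeps the highest-scoring label.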
+ +from typing import Any, Dict, List, Tuple + +from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline + + +def pred_to_label( + sentiment_prediction_scores: List[List[Dict[str, Any]]], +) -> Tuple[str, float]: + if isinstance(sentiment_prediction_scores[0], list): + sentiment_prediction_scores = sentiment_prediction_scores[0] + item_with_max_score = max( + sentiment_prediction_scores, key=lambda _dict: _dict["score"] + ) + score = item_with_max_score["score"] + return score, item_with_max_score["label"].lower() + + +def get_text_sentiment_prediction(text: str, sentiment_classifier) -> Tuple[str, float]: + return pred_to_label(sentiment_classifier(text)) + + +def load_sentiment_classifier(model_dir: str): + classifier = pipeline( + task="text-classification", + model=AutoModelForSequenceClassification.from_pretrained(model_dir), + tokenizer=AutoTokenizer.from_pretrained( + "j-hartmann/sentiment-roberta-large-english-3-classes" + ), + top_k=None, + ) + return classifier diff --git a/spiritlm/eval/stsp/stsp_constants.py b/spiritlm/eval/stsp/stsp_constants.py new file mode 100644 index 0000000000000000000000000000000000000000..26d429d6d79e6f2f3a7b11647d6a41b7597e7ef6 --- /dev/null +++ b/spiritlm/eval/stsp/stsp_constants.py @@ -0,0 +1,12 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the FAIR Noncommercial Research License +# found in the LICENSE file in the root directory of this source tree. + +from pathlib import Path + +STSP_ROOT = Path(__file__).parents[3] / "data" / "stsp_data" +STSP_DATA_ROOT = STSP_ROOT / "data" +STSP_MODEL_ROOT = STSP_ROOT / "model" +STSP_MANIFEST_ROOT = STSP_ROOT / "manifest" diff --git a/spiritlm/eval/stsp/utils.py b/spiritlm/eval/stsp/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ff2f74a661e1ac37131901b28410d8eec3112a44 --- /dev/null +++ b/spiritlm/eval/stsp/utils.py @@ -0,0 +1,122 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the FAIR Noncommercial Research License +# found in the LICENSE file in the root directory of this source tree. 
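+
+# Speech sentiment helpers for the STSP evaluation: `load_emotion_classifier`
+# loads an audio emotion classifier via `AutoModelForAudioClassification`,
+# `predict_audio` / `wav2emotion` run it on a waveform or an audio path, and
+# `wav2emotion_and_sentiment` maps the predicted emotion to a sentiment
+# (e.g. happy -> positive, angry/sad -> negative, default -> neutral).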
+ +from dataclasses import dataclass +from functools import cache +from typing import List, Optional, Tuple + +import torch +import torchaudio +from transformers import AutoFeatureExtractor, AutoModelForAudioClassification + +EXPRESSO_EMOTION_2_SENTIMENT = { + "happy": "positive", + "angry": "negative", + "sad": "negative", + "default": "neutral", +} + +EMOTION_2_SENTIMENT = { + "happy": "positive", + "angry": "negative", + "sad": "negative", + "default": "neutral", + "neutral": "neutral", + "amused": "positive", +} + + +@cache +def emotions2new_label_names_and_indices( + emotions_to_select: Tuple[str], + label_names: Tuple[str], +) -> Tuple[List[str], List[int]]: + emotion2index = {e: i for i, e in enumerate(label_names)} + sorted_indices_emotions = sorted( + [(emotion2index[emotion], emotion) for emotion in emotions_to_select] + ) + zipped = list(zip(*sorted_indices_emotions)) + return zipped + + +def expresso_emotion2_sentiment(emotion: str): + return EXPRESSO_EMOTION_2_SENTIMENT[emotion] + + +@dataclass +class ExpressoEmotionClassifier: + feature_extractor: AutoFeatureExtractor + model: AutoModelForAudioClassification + label_names: List[str] + + +def load_emotion_classifier(checkpoint_path: str) -> ExpressoEmotionClassifier: + feature_extractor = AutoFeatureExtractor.from_pretrained(checkpoint_path) + model = ( + AutoModelForAudioClassification.from_pretrained(checkpoint_path).cuda().eval() + ) + label_names = [model.config.id2label[i] for i in range(model.config.num_labels)] + print(f"Classification model loaded from {checkpoint_path} !") + return ExpressoEmotionClassifier(feature_extractor, model, label_names) + + +@torch.inference_mode() +def predict_audio( + audio, + expresso_emotion_classifier: ExpressoEmotionClassifier, + emotions_to_predict: Optional[List[str]] = None, +): + if isinstance(audio, str): + speech, _ = torchaudio.load(audio) + resampler = torchaudio.transforms.Resample( + expresso_emotion_classifier.feature_extractor.sampling_rate + ) + speech = resampler(speech).squeeze().numpy() + else: + speech = audio + + features = expresso_emotion_classifier.feature_extractor( + speech, + sampling_rate=expresso_emotion_classifier.feature_extractor.sampling_rate, + return_tensors="pt", + ) + features["input_values"] = features["input_values"].cuda() + + logits = expresso_emotion_classifier.model(**features).logits + if emotions_to_predict is not None: + (indices, label_names) = emotions2new_label_names_and_indices( + tuple(emotions_to_predict), tuple(expresso_emotion_classifier.label_names) + ) + logits = logits[:, indices] + else: + label_names = expresso_emotion_classifier.label_names + pred_id = torch.argmax(logits, dim=-1)[0].item() + + return label_names[pred_id], logits.detach().cpu().numpy() + + +def wav2emotion( + wav, + expresso_emotion_classifier: ExpressoEmotionClassifier, + emotions_to_predict: Optional[List[str]] = None, +) -> str: + label_logits = predict_audio( + audio=wav, + expresso_emotion_classifier=expresso_emotion_classifier, + emotions_to_predict=emotions_to_predict, + ) + pred_emotion = label_logits[0] + return pred_emotion + + +def wav2emotion_and_sentiment( + wav, + expresso_emotion_classifier: ExpressoEmotionClassifier, + emotions_to_predict: Optional[List[str]] = None, +) -> Tuple[str, str]: + pred_emotion = wav2emotion(wav, expresso_emotion_classifier, emotions_to_predict) + mapped_sentiment = expresso_emotion2_sentiment(pred_emotion) + return pred_emotion, mapped_sentiment diff --git a/spiritlm/eval/utils.py b/spiritlm/eval/utils.py new file mode 
100644 index 0000000000000000000000000000000000000000..7170fed0091aa51169d4a65619604cca58eed757 --- /dev/null +++ b/spiritlm/eval/utils.py @@ -0,0 +1,17 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the FAIR Noncommercial Research License +# found in the LICENSE file in the root directory of this source tree. + +import torchaudio +from spiritlm.model.spiritlm_model import Spiritlm + + +def wav_prompt(spiritlm_model: Spiritlm, wav_path: str) -> str: + wav = torchaudio.load(wav_path)[0].squeeze(0) + return spiritlm_model.SPEECH_PROMPT_PREFIX + spiritlm_model.speech_tokenizer(wav) + + +def text_prompt(spiritlm_model: Spiritlm, text: str) -> str: + return spiritlm_model.TEXT_PROMPT_PREFIX + text diff --git a/spiritlm/model/README.md b/spiritlm/model/README.md new file mode 100644 index 0000000000000000000000000000000000000000..c14278cc9310f709ac29e7cc463110a61cef7d86 --- /dev/null +++ b/spiritlm/model/README.md @@ -0,0 +1,82 @@ +# Model for Spirit LM +This repo includes the Spirit LM model wrapper. + +## Usage examples + +### Model Loading +```python +from spiritlm.model.spiritlm_model import Spiritlm + +# Spirit LM Base 7B +spirit_lm = Spiritlm("spirit-lm-base-7b") + +# Spirit LM Expressive 7B +spirit_lm = Spiritlm("spirit-lm-expressive-7b") +``` + +### Generation examples +```python +from spiritlm.model.spiritlm_model import OutputModality, GenerationInput, ContentType +from transformers import GenerationConfig + +# Generate only text +spirit_lm.generate( + output_modality=OutputModality.TEXT, + interleaved_inputs=[ + GenerationInput( + content="The largest country in the world is", + content_type=ContentType.TEXT, + ) + ], + generation_config=GenerationConfig( + temperature=0.9, + top_p=0.95, + max_new_tokens=50, + do_sample=True, + ), +) + +# Expected output format: +# [GenerationOuput(content='Russia, with an area of ...', content_type=)] + +# Generate only speech +spirit_lm.generate( + output_modality=OutputModality.SPEECH, + interleaved_inputs=[ + GenerationInput( + content="examples/audio/7143-88743-0029.flac", + content_type=ContentType.SPEECH, + ) + ], + generation_config=GenerationConfig( + temperature=0.9, + top_p=0.95, + max_new_tokens=200, + do_sample=True, + ), +) + +# Expected output format: +# [GenerationOuput(content=array([ 3.6673620e-05, 2.6468514e-04, 1.0735081e-03, ...,], dtype=float32), content_type=)] + + +# Arbitrary generation +spirit_lm.generate( + output_modality=OutputModality.ARBITRARY, + interleaved_inputs=[ + GenerationInput( + content="examples/audio/7143-88743-0029.flac", + content_type=ContentType.SPEECH, + ) + ], + generation_config=GenerationConfig( + temperature=0.9, + top_p=0.95, + max_new_tokens=200, + do_sample=True, + ), +) +# Expected output format is a list of GenerationOuput where content type could be `ContentType.TEXT' or `ContentType.SPEECH`: +# [GenerationOuput(content='xxx', content_type=), GenerationOuput(content=array([ 0.00553902, -0.03210586, ... ], dtype=float32), content_type=), GenerationOuput(content='yyy', content_type=), GenerationOuput(content=array([0.04051103, 0.03596291, 0.03381396, ..., 0.05103811, 0.05429034, ..,,], dtype=float32), content_type=)] +``` +See more examples with other modalites in [examples/speech_generation/spirit_model.ipynb](../../examples/speech_generation/spirit_model.ipynb). 
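+
+### Tuple inputs and keyword generation arguments
+The examples above can be written more compactly: `generate` also accepts the
+inputs as plain `(content_type, content)` tuples (with speech given as an audio
+path, a `torch.Tensor` or a numpy array), the `output_modality` as a string, and
+`GenerationConfig` fields directly as keyword arguments. The minimal sketch
+below uses the bundled example audio; the short text prompt is only illustrative.
+```python
+import torchaudio
+
+from spiritlm.model.spiritlm_model import Spiritlm
+
+spirit_lm = Spiritlm("spirit-lm-base-7b")
+
+# Speech content may be an audio path, a torch.Tensor or a numpy array.
+wav = torchaudio.load("examples/audio/7143-88743-0029.flac")[0].squeeze(0)
+
+outputs, prompt = spirit_lm.generate(
+    output_modality="SPEECH",  # strings work as well as OutputModality members
+    interleaved_inputs=[
+        ("text", "In a calm voice:"),  # illustrative text prefix
+        ("speech", wav),
+    ],
+    # GenerationConfig fields can be passed directly as keyword arguments
+    temperature=0.9,
+    top_p=0.95,
+    max_new_tokens=200,
+    do_sample=True,
+    return_prompt=True,  # also return the constructed token prompt
+)
+```
+Here `outputs` is a list with a single `GenerationOuput` whose `content` is the
+generated waveform as a numpy array, and `prompt` is the encoded token string
+that was fed to the model (useful for debugging).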
\ No newline at end of file diff --git a/spiritlm/model/__init__.py b/spiritlm/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..17edf7fe569a6e89feb6107141c62c96ab684ea5 --- /dev/null +++ b/spiritlm/model/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the FAIR Noncommercial Research License +# found in the LICENSE file in the root directory of this source tree. diff --git a/spiritlm/model/spiritlm_model.py b/spiritlm/model/spiritlm_model.py new file mode 100644 index 0000000000000000000000000000000000000000..2191683c3b5db17b611167b2ed7515635232c103 --- /dev/null +++ b/spiritlm/model/spiritlm_model.py @@ -0,0 +1,576 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the FAIR Noncommercial Research License +# found in the LICENSE file in the root directory of this source tree. + +import logging +import math +import os +from dataclasses import dataclass +from enum import Enum, auto +from functools import cache +from pathlib import Path +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch +import torchaudio +from spiritlm.model.utils import ( + convert_to_wav_tensor, + does_end_with_speech_token, + does_start_with_speech_token, + find_prompt_last_speech_start_position, + get_forbidden_tokens, +) +from spiritlm.speech_tokenizer import spiritlm_base, spiritlm_expressive +from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer, set_seed + +_logger = logging.getLogger(__name__) + + +# Get the base checkpoints directory from environment variable or use the default base path +base_checkpoints_dir = Path(os.getenv("SPIRITLM_CHECKPOINTS_DIR", Path(__file__).parent.parent.parent / "checkpoints")) + +# Append 'spiritlm_model' to the base path +CHECKPOINT_DIR = base_checkpoints_dir / "spiritlm_model" + +class ContentType(Enum): + TEXT = "TEXT" + SPEECH = "SPEECH" + + +class OutputModality(Enum): + TEXT = auto() + SPEECH = auto() + ARBITRARY = auto() + + +@dataclass +class GenerationInput: + content: Union[str, os.PathLike, torch.Tensor, np.ndarray] + content_type: ContentType + + @classmethod + def from_tuple(cls, tup): + content_type, content = tup + content_type = content_type.upper() + assert content_type in [ + "SPEECH", + "TEXT", + ], f"expects content_type to be one of ['SPEECH', 'TEXT'], found '{content_type}'" + if content_type == "TEXT": + content_type = ContentType.TEXT + elif content_type == "SPEECH": + content_type = ContentType.SPEECH + return cls(content=content, content_type=content_type) + + +@dataclass +class GenerationOuput: + content: Union[str, np.ndarray] + content_type: ContentType + + +InterleavedInputs = List[GenerationInput] +InterleavedOutputs = List[GenerationOuput] + + +class SpiritlmVariants(Enum): + BASE_7B = "spirit-lm-base-7b" + EXPRESSIVIE_7B = "spirit-lm-expressive-7b" + + @classmethod + def values_as_list(cls): + return [e.value for e in cls] + + +def _ensure_model_name(name: str): + if Path(name).exists(): + name = Path(name).stem + expected_names = SpiritlmVariants.values_as_list() + assert ( + name in SpiritlmVariants.values_as_list() + ), f"Unknown model name, expected one of {expected_names}" + + +def _set_device_and_return(): + if not torch.cuda.is_available(): + return "cpu" + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + torch.cuda.set_device(local_rank) + return torch.device(local_rank) + + +def 
_convert_str_output_modality(output_modality): + """Convert from string to an instance of OutputModality""" + output_modality_str_map = { + "TEXT": OutputModality.TEXT, + "SPEECH": OutputModality.SPEECH, + "ARBITRARY": OutputModality.ARBITRARY, + } + if isinstance(output_modality, str): + output_modality = output_modality.upper() + assert ( + output_modality in output_modality_str_map + ), f"invalid string output_modality (found '{output_modality}', but expects one of {list(output_modality_str_map)})" + output_modality = output_modality_str_map[output_modality] + assert isinstance(output_modality, OutputModality) + return output_modality + + +def _get_generation_inputs(interleaved_inputs): + """Convert from a list of tuple (content_type, content) to a list of GenrationInput""" + for i, item in enumerate(interleaved_inputs): + assert isinstance(item, tuple) or isinstance(item, GenerationInput), ( + "Each element of interleaved_inputs is expected to be either an instance of GenerationInput " + "or a tuple of (content_modality, content)" + ) + if isinstance(item, tuple): + interleaved_inputs[i] = GenerationInput.from_tuple(interleaved_inputs[i]) + return interleaved_inputs + + +def _overwrite_generation_config(generation_config, kwargs): + """Overwrite generation_config from the kwargs""" + if generation_config is None: + generation_config = GenerationConfig() + assert isinstance(generation_config, GenerationConfig) + gen_diff_dict = generation_config.to_diff_dict() + for attr_name, attr_value in kwargs.items(): + assert hasattr( + generation_config, attr_name + ), f"attribute '{attr_name}' not found in transformers.GenerationConfig" + if attr_name in gen_diff_dict and attr_value != gen_diff_dict[attr_name]: + _logger.warning( + f"Overwrite generation_config's {attr_name} to {attr_value}" + ) + setattr(generation_config, attr_name, attr_value) + return generation_config + + +class Spiritlm: + TEXT_PROMPT_PREFIX = "[Text]" + SPEECH_PROMPT_PREFIX = "[Speech]" + + def __init__(self, name: str, **speech_tokenizer_kwargs): + if Path(name).exists(): + path = name + else: + path = CHECKPOINT_DIR / name + _ensure_model_name(name) + self.device = _set_device_and_return() + _logger.info(f"Loading SPIRIT-LM model from the path {path}...") + self.model = LlamaForCausalLM.from_pretrained( + path, torch_dtype=torch.bfloat16 + ).to(self.device) + _logger.info(f"SPIRIT-LM model is loaded.") + self.tokenizer = LlamaTokenizer.from_pretrained( + pretrained_model_name_or_path=path, + add_bos_token=True, + add_eos_token=False, + ) + _logger.info("Loading SPIRIT-LM speech tokenizers ...") + if name == SpiritlmVariants.BASE_7B.value: + self.speech_tokenizer = spiritlm_base(**speech_tokenizer_kwargs) + self.is_expressive_model = False + elif name == SpiritlmVariants.EXPRESSIVIE_7B.value: + self.speech_tokenizer = spiritlm_expressive(**speech_tokenizer_kwargs) + self.is_expressive_model = True + _logger.info("SPIRIT-LM speech tokenizers are loaded.") + + def _build_prompt( + self, + generation_inputs: List[GenerationInput], + output_modality: OutputModality, + ) -> str: + """ + Build the prompt according the input content and the output modality. 
+ """ + if not isinstance(output_modality, OutputModality): + raise ValueError(f"Unknown output_modality: {output_modality}") + prompts = [] + prev_modality = None + for gen_input in generation_inputs: + if gen_input.content_type.value == ContentType.SPEECH.value: + gen_input.content = convert_to_wav_tensor(gen_input.content) + if prev_modality != "s": + prompts.append(Spiritlm.SPEECH_PROMPT_PREFIX) + prompts.append(self.speech_tokenizer(gen_input.content)) + prev_modality = "s" # speech + elif gen_input.content_type.value == ContentType.TEXT.value: + if prev_modality != "t": + prompts.append(Spiritlm.TEXT_PROMPT_PREFIX) + prompts.append(gen_input.content) + prev_modality = "t" # text + else: + raise ValueError( + f"Unknown content type: {gen_input.content_type.value}" + ) + if output_modality == OutputModality.TEXT: + if prev_modality != "t": + prompts.append(Spiritlm.TEXT_PROMPT_PREFIX) + elif output_modality == OutputModality.SPEECH: + if prev_modality != "s": + prompts.append(Spiritlm.SPEECH_PROMPT_PREFIX) + return "".join(prompts) + + @cache + def _build_forbidden_tokens( + self, + output_modality: OutputModality, + ) -> List[int]: + """ + Build a set of token ids that we don't want to generate according the modality direction. + + For instance, when the modality direction is speech to text (S2T), i.e., we continue + generating text given a speech prompt, we want that the output contains only the text tokens. + """ + if output_modality == OutputModality.TEXT: + forbidden_tokens = get_forbidden_tokens( + ban_special_tokens=True, + generate_only_text=True, + ban_expressivity_tokens=True if self.is_expressive_model else False, + ) + elif output_modality == OutputModality.SPEECH: + forbidden_tokens = get_forbidden_tokens( + ban_special_tokens=True, + generate_only_speech=True, + ) + elif output_modality == OutputModality.ARBITRARY: + forbidden_tokens = [] + else: + raise ValueError(f"Unknown output_modality: {output_modality}") + return forbidden_tokens + + def _parse_speech_and_text( + self, + generated_content: str, + ): + # TODO: clean this function, it is too long! 
+ splits = [] + i = 0 + last_pos = len(generated_content) + char_and_types = [] + is_speech_token = False + is_text_token = False + text_prefix_length = len(Spiritlm.TEXT_PROMPT_PREFIX) + speech_prefix_length = len(Spiritlm.SPEECH_PROMPT_PREFIX) + while i < last_pos: + ch = generated_content[i] + j = i + if ch == "[": + if ( + j + text_prefix_length - 1 < last_pos + and generated_content[j : j + text_prefix_length] + == Spiritlm.TEXT_PROMPT_PREFIX + ): # text prefix token + j += text_prefix_length # skip "[Text] + elif ( + j + speech_prefix_length - 1 < last_pos + and generated_content[j : j + speech_prefix_length] + == Spiritlm.SPEECH_PROMPT_PREFIX + ): # speech prefix token + j += speech_prefix_length # skip "[Speech]" + elif j + 2 < last_pos and generated_content[j + 1 : j + 3] in ( + "Hu", + "Pi", + "St", + ): + j += 3 # skip "["" and Hu/Pi/St + while j < last_pos and generated_content[j] != "]": + j += 1 + j += 1 # skip "]" + is_speech_token = True + else: # other texts starting with "[" e.g., "[abc" + is_text_token = True + j += 1 + else: + is_text_token = True + while j < last_pos and generated_content[j] != "[": + j += 1 + + cur_content = generated_content[i:j] + if is_speech_token: + if len(char_and_types) and char_and_types[-1][1] == "t": + splits.append( + ( + "".join( + ( + content_and_type[0] + for content_and_type in char_and_types + ) + ), + "t", + ) + ) + char_and_types = [] + char_and_types.append((cur_content, "s")) # speech + elif is_text_token: + if len(char_and_types) and char_and_types[-1][1] == "s": + splits.append( + ( + "".join( + ( + content_and_type[0] + for content_and_type in char_and_types + ) + ), + "s", + ) + ) + char_and_types = [] + char_and_types.append((cur_content, "t")) # text + is_speech_token, is_text_token = False, False + i = j + if len(char_and_types): + if char_and_types[-1][1] == "t": + splits.append( + ( + "".join( + (content_and_type[0] for content_and_type in char_and_types) + ), + "t", + ) + ) + else: + splits.append( + ( + "".join( + (content_and_type[0] for content_and_type in char_and_types) + ), + "s", + ) + ) + return splits + + def _decode_from_generated_output( + self, + output_modality: OutputModality, + generated_content: str, + prompt: str, + speaker_id: int = 2, + ) -> InterleavedOutputs: + """ + Decode the generated tokens according the modality direction. + + If the output is text, we return what it is. + If the output is speech, we decode speech tokens by the speech tokenizer. + If the output is arbitrary, we decode the generated content according to the its modality. 
+ """ + + def _decode( + modality: OutputModality, + gen: str, + ) -> InterleavedOutputs: + if modality == OutputModality.TEXT: + return [ + GenerationOuput( + content=gen, + content_type=ContentType.TEXT, + ) + ] + elif modality == OutputModality.SPEECH: + return [ + GenerationOuput( + content=self.speech_tokenizer.decode( + gen, speaker_id=speaker_id + ), + content_type=ContentType.SPEECH, + ) + ] + elif modality == OutputModality.ARBITRARY: + decoded_chunks = [] + for i, (chunk_content, chunk_modality) in enumerate( + self._parse_speech_and_text(gen) + ): + if chunk_modality == "s": + # TODO: the way of finding Hubert token could be false positive + nb_content_hubert_tokens = len(chunk_content.split("[Hu")) + decoded = _decode( + modality=OutputModality.SPEECH, + gen=chunk_content, + )[0] + if i == 0 and is_last_content_speech: + # edge case when the prompt ends with speech and the generation starts with speech + nb_prompt_hubert_tokens = ( + len(prompt[last_speech_start_pos:].split("[Hu")) - 1 + ) # minus the one in prefix + if nb_content_hubert_tokens - nb_prompt_hubert_tokens < 25: + # continued speech from the prompt is too short + continue + # we drop the prompt part from the generation + prompt_ratio = ( + nb_prompt_hubert_tokens / nb_content_hubert_tokens + ) + decoded.content = decoded.content[ + math.ceil(decoded.content.size * prompt_ratio) : + ] + elif i > 0 and nb_content_hubert_tokens < 25: + # new speech in generation is too short + continue + else: + decoded = _decode( + modality=OutputModality.TEXT, + gen=chunk_content, + )[0] + decoded_chunks.append(decoded) + return decoded_chunks + else: + raise ValueError(f"Unknown output_modality: {output_modality}") + + generated_new_content = generated_content[len(prompt) :].strip() + is_last_content_speech, last_speech_start_pos = False, 0 + if ( + output_modality == OutputModality.ARBITRARY + and does_end_with_speech_token(prompt) + and does_start_with_speech_token(generated_new_content) + ): + is_last_content_speech = True + last_speech_start_pos = find_prompt_last_speech_start_position(prompt) + # If the prompt ends with speech, we decode both the prompt and the generation + # because we probably don't have pitch and style tokens in the generation. + generated_new_content = generated_content[last_speech_start_pos:] + return _decode(output_modality, generated_new_content) + + def generate( + self, + interleaved_inputs: Optional[List[Union[GenerationInput, tuple]]] = None, + prompt: Optional[str] = None, + output_modality: Union[OutputModality, str] = OutputModality.ARBITRARY, + generation_config: Optional[GenerationConfig] = None, + force_tokens_to_output_modality: bool = True, + speaker_id: int = 2, + return_prompt: bool = False, + seed: Optional[int] = None, + **kwargs, # GenerationConfig args can be passing here + ) -> Union[InterleavedOutputs, Tuple[InterleavedOutputs, str]]: + """ + Speech/text generation given speech/text prompt. + + Parameters: + interleaved_inputs (List of `GenerationInput` or list of tuples): + List of speech/text inputs. + Each element can be an instance of `GenerationInput` or a tuple of (content_type, content) + Text content is string; Speech content is either audio path, audio tensor, or nummpy array. + The prompt will be built by interleaving them in order. + prompt (str): + The prompt in encoded tokens string, + e.g., "[Speech][Hu99][Hu38]...", "[Text]whatever text" or mix of speech & text. 
+ output_modality (str or `OutputModality`): + 'TEXT' or OutputModality.TEXT: generate text + 'SPEECH' or OutputModality.SPEECH: generate speech + 'ARBITRARY' or OutputModality.ARBITRARY: generate arbitrary modality output (default) + generation_config (`GenerationConfig`): + Generation configuration used by Huggingface `generate` function. + force_tokens_to_output_modality (bool): + Whether to force generating tokens to the output modality that you specify in `output_modality`. + For instance, if the `output_modality` is TEXT and force_tokens_to_output_modality is True, + we force the model to generate only the text tokens. + speaker_id (int): + Speaker id, 0, 1, 2 or 3. + return_prompt (bool): + Whether to return the constructed prompt (could be used for debug). + **kwargs: + Directly passing arguments from transformers.GenerationConfig (e.g. temperature, max_new_tokens, do_sample). + See: https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig + """ + + if seed is not None: + _logger.info(f"Set seed to {seed}") + set_seed(seed) + + # Set the output modality + output_modality = _convert_str_output_modality(output_modality) + + # Get the input prompt + assert not ( + interleaved_inputs is None and prompt is None + ), "interleaved_inputs and prompt can not both be None" + if ( + prompt is not None + and interleaved_inputs is not None + and len(interleaved_inputs) > 0 + ): + _logger.warning( + "When prompt is specified, interleaved_inputs will not be used." + ) + if prompt is None: + if not isinstance(interleaved_inputs, list): + interleaved_inputs = [interleaved_inputs] + interleaved_inputs = _get_generation_inputs(interleaved_inputs) + prompt = self._build_prompt( + interleaved_inputs, + output_modality, + ) + + # Get input tensor + inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device) + + # Get generation config from kwargs + generation_config = _overwrite_generation_config(generation_config, kwargs) + + # Get forbidden token ids + if ( + force_tokens_to_output_modality + and output_modality != OutputModality.ARBITRARY + ): + forbidden_token_ids = [ + [tok_id] for tok_id in self._build_forbidden_tokens(output_modality) + ] + else: + forbidden_token_ids = None + + # Perform the generation + generate_ids = self.model.generate( + **inputs, + generation_config=generation_config, + bad_words_ids=forbidden_token_ids, + pad_token_id=-1, + ) + + # Decode the output + gen = self.tokenizer.batch_decode( + generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False + )[0] + try: + decoded_output = self._decode_from_generated_output( + output_modality=output_modality, + generated_content=gen, + prompt=prompt, + speaker_id=speaker_id, + ) + except Exception as e: + _logger.error(f"Fail to decode the content: {gen[len(prompt) :].strip()}") + raise e + + if return_prompt: + return decoded_output, prompt + else: + return decoded_output + + +if __name__ == "__main__": + spirit_lm = Spiritlm("spirit-lm-expressive-7b") + # run several time to test speech text interleaved outputs + wav = torchaudio.load("examples/audio/7143-88743-0029.flac")[0].squeeze() + for i in range(5): + outs = spirit_lm.generate( + output_modality=OutputModality.ARBITRARY, + interleaved_inputs=[ + GenerationInput( + content=wav, + content_type=ContentType.SPEECH, + ) + ], + generation_config=GenerationConfig( + temperature=0.9, + top_p=0.95, + max_new_tokens=200, + do_sample=True, + ), + ) + print("-" * 100) + print(i) + print("-" * 100) + print(outs) diff 
--git a/spiritlm/model/utils.py b/spiritlm/model/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0d2380bb1553d15a9f02e04d154af9293d0963ba --- /dev/null +++ b/spiritlm/model/utils.py @@ -0,0 +1,126 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the FAIR Noncommercial Research License +# found in the LICENSE file in the root directory of this source tree. + +import os +import re +from io import BytesIO +from typing import List, Optional, Union + +import numpy as np +import torch +import torchaudio + +EXPECTED_SAMPLING_RATE = 16_000 + + +def find_prompt_last_speech_start_position(prompt: str) -> Optional[int]: + prev_end = None + # revert the prompt so we can search from right to left, the speech token patterns are also reverted. + for match in re.finditer("(\]\d+uH\[)|(\]\d+iP\[)|(\]\d+tS\[)", prompt[::-1]): + start, end = match.start(), match.end() + if prev_end is not None and start != prev_end: + return len(prompt) - prev_end + prev_end = end + if prev_end is None: + # speech token is not found in the prompt + return None + return len(prompt) - prev_end + + +def convert_to_wav_tensor( + content: Union[str, os.PathLike, torch.Tensor, np.ndarray] +) -> torch.Tensor: + if isinstance(content, os.PathLike) or isinstance(content, str): + audio_path = str(content) + wav, sr = torchaudio.load(audio_path) + if sr != EXPECTED_SAMPLING_RATE: + wav = torchaudio.functional.resample( + wav, orig_freq=sr, new_freq=EXPECTED_SAMPLING_RATE + ) + elif isinstance(content, np.ndarray): + wav = torch.from_numpy(content) + elif isinstance(content, bytes): + wav, sr = torchaudio.load(BytesIO(content)) + if sr != EXPECTED_SAMPLING_RATE: + wav = torchaudio.functional.resample( + wav, orig_freq=sr, new_freq=EXPECTED_SAMPLING_RATE + ) + else: + wav = content + + # TODO: what about stereo ? + + return wav.squeeze() + + +def does_start_with_speech_token(encoded_string) -> bool: + if ( + encoded_string is None or len(encoded_string) <= 4 + ): # shortest speech token is "[Hu1]" + return False + if encoded_string[0] != "[": + return False + end_pos = 1 + while end_pos < len(encoded_string): + if encoded_string[end_pos] == "]" and end_pos >= 4: + if any(encoded_string[1:3].startswith(tok) for tok in ["Hu", "Pi", "St"]): + return True + return False + # longest speech token is "[Huxxxxx]" + if end_pos >= 10: + return False + end_pos += 1 + return False + + +def does_end_with_speech_token(encoded_string: str) -> bool: + if ( + encoded_string is None or len(encoded_string) <= 4 + ): # shortest speech token is "[Hu1]" + return False + if encoded_string[-1] != "]": + return False + start_pos = len(encoded_string) - 2 + while start_pos >= 0: + if encoded_string[start_pos] == "[" and start_pos + 3 < len(encoded_string): + if any( + encoded_string[start_pos + 1 : start_pos + 3].startswith(tok) + for tok in ["Hu", "Pi", "St"] + ): + return True + return False + # longest speech token is "[Huxxxxx]" + if start_pos < len(encoded_string) - 10: + return False + start_pos -= 1 + return False + + +def get_forbidden_tokens( + ban_special_tokens: bool = True, + generate_only_speech: bool = False, + generate_only_text: bool = False, + ban_expressivity_tokens: bool = False, +) -> List[int]: + assert not ( + generate_only_speech and generate_only_text + ), "Nothing will be generated when generate_only_speech and generate_only_text is all True." 
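+    # Token-id layout assumed below: 0-31999 are text tokens, 32000/32001 are
+    # the [Text]/[Speech] prefixes, 32002-32502 are HuBERT units, 32503-32566
+    # are pitch units and 32567-32666 are style units.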
+ forbidden_tokens = [] + if ban_special_tokens: + forbidden_tokens += [ + 32000, + 32001, + ] # [Text], [Speech] + if generate_only_speech: + forbidden_tokens += list(range(32000)) + elif generate_only_text: + forbidden_tokens += list(range(32002, 32002 + 501)) # hubert tokens + if ban_expressivity_tokens: + forbidden_tokens += list(range(32503, 32503 + 64)) # pitch tokens + forbidden_tokens += list( + range(32567, 32567 + 100) + ) # forbidden style tokens + return forbidden_tokens diff --git a/spiritlm/speech_tokenizer/README.md b/spiritlm/speech_tokenizer/README.md new file mode 100644 index 0000000000000000000000000000000000000000..8f1740334475c88d3593b91af9e1c3c23827d8bc --- /dev/null +++ b/spiritlm/speech_tokenizer/README.md @@ -0,0 +1,42 @@ +# Speech Tokenization for Spirit LM + +This repo contains the speech encoder/decoder used for the Spirit LM. + +Here is an example of how to use spiritlm_tokenizer + +```python +import IPython.display as ipd +from spiritlm.speech_tokenizer import spiritlm_base, spiritlm_expressive + +tokenizer = spiritlm_base() # base version, only has hubert units +# tokenizer = spiritlm_expressive() # expressive version, with pitch & style units + +# Input audio +audio = "examples/audio/7143-88743-0029.flac" +print('Original audio:') +ipd.display(ipd.Audio(audio)) + +## encode_units +print('\nEncode audio into units (not deduplicated) \n', '-'*20) +units = tokenizer.encode_units(audio) +print(units) +# > {'audio': '.../audio/7143-88743-0029.flac', 'hubert': '99 49 38 149 149 71...'} + +## encode_string +print('\nEncode audio into string (deduplicated and sorted units) \n', '-'*20) +string_tokens = tokenizer.encode_string(audio) +print(string_tokens) +# > '[Hu99][Hu49][Hu38][Hu149][Hu71]...' + +## decode from units +print('\nDecode back to audio from units (not deduplicated) \n', '-'*20) +resyn_wav = tokenizer.decode(units, speaker_id=2, dur_pred=False) +ipd.display(ipd.Audio(resyn_wav, rate=16000)) + +## decode from string +print('\nDecode back to audio from string (deduplicated and sorted units) \n', '-'*20) +resyn_dedup_wav = tokenizer.decode(string_tokens, speaker_id=2) +ipd.display(ipd.Audio(resyn_dedup_wav, rate=16000)) +``` + +An example notebook can be found in [examples/speech_tokenizer/spiritlm_speech_tokenizer.ipynb](../../examples/speech_tokenizer/spiritlm_speech_tokenizer.ipynb). \ No newline at end of file diff --git a/spiritlm/speech_tokenizer/__init__.py b/spiritlm/speech_tokenizer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cafbd9bb04ca0ac544b5044576b7d451ec32afe1 --- /dev/null +++ b/spiritlm/speech_tokenizer/__init__.py @@ -0,0 +1,74 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the FAIR Noncommercial Research License +# found in the LICENSE file in the root directory of this source tree. 
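+
+# Factory functions for the Spirit LM speech tokenizers: `spiritlm_base()`
+# combines the HuBERT unit encoder with a HiFi-GAN vocoder, while
+# `spiritlm_expressive()` additionally wires in the pitch (F0) and style
+# encoders. Loaded sub-models are cached in module-level globals so that
+# repeated calls do not reload the checkpoints.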
+ +from .f0 import spiritlm_expressive_f0 +from .hifigan import spiritlm_base_hifigan, spiritlm_expressive_hifigan_w2v2 +from .hubert import spiritlm_hubert +from .spiritlm_tokenizer import SpiritLMTokenizer +from .style_encoder import spiritlm_expressive_style_encoder_w2v2 + +# Trick to avoid reloading the same model twice when calling multiple times +HUBERT = None +HIFIGAN_BASE = None +F0 = None +STYLE_W2V2 = None +HIFIGAN_EXPRESSIVE_W2V2 = None + + +def spiritlm_base( + default_speaker=2, + default_style=8, # conv-default +): + # Hubert + global HUBERT + if HUBERT is None: + HUBERT = spiritlm_hubert() + + # Hifigan + global HIFIGAN_BASE + if HIFIGAN_BASE is None: + HIFIGAN_BASE = spiritlm_base_hifigan( + default_speaker=default_speaker, default_style=default_style + ) + + return SpiritLMTokenizer( + hubert_model=HUBERT, + hifigan_model=HIFIGAN_BASE, + ) + + +def spiritlm_expressive(f0_backbone="fcpe", default_speaker=2): + # Hubert + global HUBERT + if HUBERT is None: + HUBERT = spiritlm_hubert() + + # F0 + global F0 + if F0 is None: + F0 = spiritlm_expressive_f0(f0_backbone=f0_backbone) + + # Style + global STYLE_W2V2 + if STYLE_W2V2 is None: + STYLE_W2V2 = spiritlm_expressive_style_encoder_w2v2() + + # Hifigan + global HIFIGAN_EXPRESSIVE_W2V2 + if HIFIGAN_EXPRESSIVE_W2V2 is None: + HIFIGAN_EXPRESSIVE_W2V2 = spiritlm_expressive_hifigan_w2v2( + default_speaker=default_speaker + ) + + return SpiritLMTokenizer( + hubert_model=HUBERT, + pitch_model=F0, + style_model=STYLE_W2V2, + hifigan_model=HIFIGAN_EXPRESSIVE_W2V2, + hubert_key="hubert", + pitch_key="pitch", + style_key="style", + ) diff --git a/spiritlm/speech_tokenizer/__pycache__/__init__.cpython-310.pyc b/spiritlm/speech_tokenizer/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..97d2f6dc91d5110122267133993bb40f9a16ed86 Binary files /dev/null and b/spiritlm/speech_tokenizer/__pycache__/__init__.cpython-310.pyc differ diff --git a/spiritlm/speech_tokenizer/__pycache__/spiritlm_tokenizer.cpython-310.pyc b/spiritlm/speech_tokenizer/__pycache__/spiritlm_tokenizer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8ea1e0b993eb66b299f1ec0d3534ac73b635fab3 Binary files /dev/null and b/spiritlm/speech_tokenizer/__pycache__/spiritlm_tokenizer.cpython-310.pyc differ diff --git a/spiritlm/speech_tokenizer/f0/__init__.py b/spiritlm/speech_tokenizer/f0/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..839ef038834554c6d2945b7e806f76a266fdc22d --- /dev/null +++ b/spiritlm/speech_tokenizer/f0/__init__.py @@ -0,0 +1,35 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the FAIR Noncommercial Research License +# found in the LICENSE file in the root directory of this source tree. 
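+
+# F0 (pitch) tokenizer factory: `spiritlm_expressive_f0()` builds an
+# `F0Tokenizer` from the VQ-VAE quantizer checkpoint under
+# $SPIRITLM_CHECKPOINTS_DIR/speech_tokenizer (defaulting to the repository's
+# `checkpoints/` folder), using the FCPE pitch extractor by default at 16 kHz
+# with an 80-sample hop.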
+ +from pathlib import Path +import os + +import torch + +from .f0_tokenizer import F0Tokenizer + +# Get the base checkpoints directory from environment variable or use the default base path +base_checkpoints_dir = Path(os.getenv("SPIRITLM_CHECKPOINTS_DIR", Path(__file__).parents[3] / "checkpoints")) + +# Append 'speech_tokenizer' to the base path +CHECKPOINT_DIR = base_checkpoints_dir / "speech_tokenizer" + +CURRENT_DEVICE = ( + torch.device(torch.cuda.current_device()) + if torch.cuda.is_available() + else "mps" if torch.backends.mps.is_available() else "cpu" +) + + +def spiritlm_expressive_f0(f0_backbone="fcpe"): + return F0Tokenizer( + f0_extractor_method=f0_backbone, + quantizer_path=CHECKPOINT_DIR / "vqvae_f0_quantizer/model.pt", + hop_length=80, + sampling_rate=16000, + interpolate=True, + device=CURRENT_DEVICE, + ) diff --git a/spiritlm/speech_tokenizer/f0/__pycache__/__init__.cpython-310.pyc b/spiritlm/speech_tokenizer/f0/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..99af0db9117db87db58d79adb90edd61ef4ff83e Binary files /dev/null and b/spiritlm/speech_tokenizer/f0/__pycache__/__init__.cpython-310.pyc differ diff --git a/spiritlm/speech_tokenizer/f0/__pycache__/f0_extractor.cpython-310.pyc b/spiritlm/speech_tokenizer/f0/__pycache__/f0_extractor.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1739221e5b265c867cf4cd62dbc7cd5ae532aa45 Binary files /dev/null and b/spiritlm/speech_tokenizer/f0/__pycache__/f0_extractor.cpython-310.pyc differ diff --git a/spiritlm/speech_tokenizer/f0/__pycache__/f0_tokenizer.cpython-310.pyc b/spiritlm/speech_tokenizer/f0/__pycache__/f0_tokenizer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..30af78f08f8bc3430d37503634299fcddf067904 Binary files /dev/null and b/spiritlm/speech_tokenizer/f0/__pycache__/f0_tokenizer.cpython-310.pyc differ diff --git a/spiritlm/speech_tokenizer/f0/__pycache__/vqvae.cpython-310.pyc b/spiritlm/speech_tokenizer/f0/__pycache__/vqvae.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3393c272f971323cc1f79f08f016be3274df0e85 Binary files /dev/null and b/spiritlm/speech_tokenizer/f0/__pycache__/vqvae.cpython-310.pyc differ diff --git a/spiritlm/speech_tokenizer/f0/f0_extractor.py b/spiritlm/speech_tokenizer/f0/f0_extractor.py new file mode 100644 index 0000000000000000000000000000000000000000..f80d61b5b1d1b2440665c92e54900067bb73153a --- /dev/null +++ b/spiritlm/speech_tokenizer/f0/f0_extractor.py @@ -0,0 +1,200 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the FAIR Noncommercial Research License +# found in the LICENSE file in the root directory of this source tree. 
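+
+# Frame-level F0 extraction backends. `F0Extractor.load_audio` resamples audio
+# files to `sampling_rate`, and each extractor yields one F0 value per
+# `hop_length` samples (i.e. sampling_rate / hop_length frames per second);
+# `pYAAPTF0Extractor` and `FCPEF0Extractor` implement `compute_f0_uv`, and
+# `load_f0_extractor` selects between the "pyaapt" and "fcpe" methods.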
+ + +import logging + +import numpy as np +import torch +import torch.nn as nn +import torchaudio + +_logger = logging.getLogger(__name__) + + +class F0Extractor(nn.Module): + + def __init__( + self, + hop_length=80, + sampling_rate=16000, + interpolate=True, + ): + """Each second will have sampling_rate/hop_length frames.""" + super().__init__() + + self.hop_length = hop_length + self.sampling_rate = sampling_rate + self.interpolate = interpolate + + def load_audio(self, path, mono=True): + wav, sr = torchaudio.load(path) + if sr != self.sampling_rate: + wav = torchaudio.functional.resample( + wav, orig_freq=sr, new_freq=self.sampling_rate + ) + if mono and wav.ndim == 2: + wav = wav.mean(dim=0) + wav = wav.numpy() + return wav + + def compute_f0_uv(self, wav, interpolate=True): + raise NotImplementedError("Not implemented!") + + @torch.inference_mode() + def forward(self, audio, vuv=False): + if isinstance(audio, str): + audio = self.load_audio(audio) + + f0, uv = self.compute_f0_uv(audio, interpolate=self.interpolate) + + if not vuv: + return f0 + else: + return f0, uv + + +class pYAAPTF0Extractor(F0Extractor): + + def compute_f0_uv(self, wav, interpolate=True): + pitch = self.get_pitch(wav) + # take interpolate, otherwise pitch.samp_values + # pyaapt has some problems with pitch.samp_values, so do it manually (from pgslm) + f0 = pitch.samp_values + if interpolate: + f0 = self.interpolate_f0(f0) + vuv = pitch.vuv + return f0, vuv + + def get_pitch(self, wav): + try: + import amfm_decompy.basic_tools as basic + import amfm_decompy.pYAAPT as pYAAPT + from librosa.util import normalize + except ImportError as error: + raise ImportError( + "To use pYAAPTF0Extractor, please install AMFM-decompy and librosa" + ) from error + + wav = wav.squeeze() + assert wav.ndim == 1 + if not isinstance(wav, np.ndarray): + wav = np.array(wav) + frame_length = 20.0 # ms + to_pad = int(frame_length / 1000 * self.sampling_rate) // 2 + + # remove remainders for large hop length + n_frames = len(wav) // self.hop_length * self.hop_length + wav = wav[:n_frames] + + audio = normalize(wav) * 0.95 + if self.hop_length == 80: + audio = np.pad(audio, (to_pad, to_pad), "constant", constant_values=0) + audio = basic.SignalObj(audio, self.sampling_rate) + pitch = pYAAPT.yaapt( + audio, + frame_length=frame_length, + frame_space=self.hop_length / self.sampling_rate * 1000, + nccf_thresh1=0.25, + tda_frame_length=25.0, + ) + + return pitch + + def interpolate_f0(self, f0, fill_extremities=True): + try: + from scipy.interpolate import interp1d + except ImportError as error: + raise ImportError( + "To use pYAAPTF0Extractor, please install scipy (`pip install scipy`)" + ) from error + + orig_t = np.arange(f0.shape[0]) + f0_interp = f0[:] + ii = f0_interp != 0 + if ii.sum() > 1: + f0_interp = interp1d( + orig_t[ii], + f0_interp[ii], + bounds_error=False, + kind="linear", + fill_value=0, + )(orig_t) + + # Fill extreme values with border values + if fill_extremities: + f0_interp[: orig_t[ii][0]] = f0_interp[ii][0] + f0_interp[orig_t[ii][-1] + 1 :] = f0_interp[ii][-1] + + return f0_interp + + +class FCPEF0Extractor(F0Extractor): + + def __init__( + self, + hop_length=80, + sampling_rate=16000, + interpolate=True, + device=None, + ): + try: + from torchfcpe import spawn_bundled_infer_model + except ImportError as error: + raise ImportError( + "To use FCPEF0Extractor, please install torchfcpe (`pip install torchfcpe`)" + ) from error + + super().__init__( + hop_length=hop_length, sampling_rate=sampling_rate, interpolate=interpolate 
+ ) + + self.model = spawn_bundled_infer_model(device=device) + + def compute_f0_uv(self, wav, interpolate=True): + wav = wav.squeeze() + assert wav.ndim == 1 + f0_target_length = (len(wav) // self.hop_length) + 1 + if not isinstance(wav, torch.Tensor): + wav = torch.from_numpy(wav) + wav = wav.float().unsqueeze(0).unsqueeze(-1) + f0, uv = self.model.infer( + wav, + sr=self.sampling_rate, + decoder_mode="local_argmax", # Recommended mode + threshold=0.05, # Threshold for V/UV decision + f0_min=50, # Minimum pitch + f0_max=1100, # Maximum pitch + interp_uv=interpolate, # Interpolate unvoiced frames + output_interp_target_length=f0_target_length, # Interpolate to target length + retur_uv=True, + ) + vuv = 1 - uv + return f0.squeeze().cpu().numpy(), vuv.squeeze().cpu().numpy() + + +def load_f0_extractor( + f0_extractor_method, hop_length, sampling_rate, interpolate, device=None +): + expected_methods = ["pyaapt", "fcpe"] + assert ( + f0_extractor_method in expected_methods + ), f"Unexpected f0 extractor method: {f0_extractor_method} (choices are: {expected_methods})" + if f0_extractor_method == "pyaapt": + f0_extractor = pYAAPTF0Extractor( + hop_length=hop_length, sampling_rate=sampling_rate, interpolate=interpolate + ) + elif f0_extractor_method == "fcpe": + f0_extractor = FCPEF0Extractor( + hop_length=hop_length, + sampling_rate=sampling_rate, + interpolate=interpolate, + device=device, + ) + _logger.info( + f"Using '{f0_extractor_method}' f0 extractor method (choices are: {expected_methods})" + ) + return f0_extractor diff --git a/spiritlm/speech_tokenizer/f0/f0_tokenizer.py b/spiritlm/speech_tokenizer/f0/f0_tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..a262cf24327ff9dc933070c254a8d09cae719ec8 --- /dev/null +++ b/spiritlm/speech_tokenizer/f0/f0_tokenizer.py @@ -0,0 +1,115 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the FAIR Noncommercial Research License +# found in the LICENSE file in the root directory of this source tree. + + +import logging +import os + +import torch + +from .f0_extractor import load_f0_extractor +from .vqvae import load_vqvae + +_logger = logging.getLogger(__name__) + + +class F0Tokenizer(torch.nn.Module): + + def __init__( + self, + f0_extractor_method, + quantizer_path, + f0_speaker_stats=None, + hop_length=80, + sampling_rate=16000, + interpolate=False, + device="cuda", + ): + super().__init__() + + self.f0_extractor = load_f0_extractor( + f0_extractor_method=f0_extractor_method, + hop_length=hop_length, + sampling_rate=sampling_rate, + interpolate=interpolate, + device=device, + ) + + self.quantizer, self.quantizer_cfg = load_vqvae(quantizer_path) + self.quantizer.eval() + self.quantizer.to(device) + # Load speaker stats + self.speaker_f0_stats = f0_speaker_stats + if self.speaker_f0_stats is None and ( + self.quantizer_cfg.get("speaker_norm", False) + or "norm_" in self.quantizer_cfg.features + ): + speaker_stats_path = self.quantizer_cfg.get("speaker_stats", None) + if speaker_stats_path is not None and os.path.exists(speaker_stats_path): + self.speaker_f0_stats = torch.load( + speaker_stats_path, weights_only=True + ) + _logger.info(f"Speaker f0 stats loaded from '{speaker_stats_path}'") + else: + _logger.info( + "It seems that model is using normalized f0 but no speaker stats is given, will infer mean f0 from input utterance." 
+ ) + + # this is useful for determining the device + self.register_buffer( + "_float_tensor", torch.tensor([0], dtype=torch.float, device=device) + ) + + @property + def device(self): + return self._float_tensor.device + + def quantize_vqvae(self, f0, vuv, speaker=None, compute_vqvae_pred=False): + assert self.quantizer_cfg.features in [ + "f0_interp,vuv", + "f0,vuv", + "norm_f0_interp,vuv", + "norm_f0,vuv", + ], self.quantizer_cfg.features + + if not isinstance(f0, torch.Tensor): + f0 = torch.tensor(f0) + if not isinstance(vuv, torch.Tensor): + vuv = torch.tensor(vuv) + + # normalize f0 + if ( + self.quantizer_cfg.get("speaker_norm", False) + or "norm_" in self.quantizer_cfg.features + ): + mask = f0 != 0 + if speaker is not None and speaker in self.speaker_f0_stats: + mean = self.speaker_f0_stats[speaker]["f0_mean"] + else: + # Get statistics from utterance (maybe it is more accurate to get mean from voiced segments) + vuv_mask = vuv != 0 + mean = torch.mean(f0[vuv_mask]) + f0[mask] = f0[mask] - mean + + x = torch.stack([f0, vuv]) # (2, T) + x = x.float().unsqueeze(0).to(self.device) # (1, 2, T) + if not compute_vqvae_pred: + quant_f0 = self.quantizer(x, compute_pred=False) + quant_f0 = quant_f0[0].squeeze(0) + return quant_f0 + else: + quant_f0, pred = self.quantizer(x, compute_pred=True) + quant_f0 = quant_f0[0].squeeze(0) + pred = pred[0] + return quant_f0, pred + + def forward(self, x, speaker=None, dense=False, compute_vqvae_pred=False): + f0, vuv = self.f0_extractor(x, vuv=True) + if dense: + return f0 + return self.quantize_vqvae( + f0, vuv, speaker=speaker, compute_vqvae_pred=compute_vqvae_pred + ) diff --git a/spiritlm/speech_tokenizer/f0/vqvae.py b/spiritlm/speech_tokenizer/f0/vqvae.py new file mode 100644 index 0000000000000000000000000000000000000000..e5732e02b826e3ac0d055920386dcb227a95867f --- /dev/null +++ b/spiritlm/speech_tokenizer/f0/vqvae.py @@ -0,0 +1,654 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the FAIR Noncommercial Research License +# found in the LICENSE file in the root directory of this source tree. 
+ +# VQ-VAE model, adapted from: +# - https://github.com/openai/jukebox/blob/master/jukebox/vqvae/ +# - https://github.com/facebookresearch/speech-resynthesis/blob/main/modules/vq.py + + +import logging +import math +from pathlib import Path + +import numpy as np +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F +from omegaconf import OmegaConf + +_logger = logging.getLogger(__name__) + + +def load_vqvae(checkpoint): + config = Path(checkpoint).parent / "config.yaml" + cfg = OmegaConf.load(config) + model = VQVAE(cfg) + state_dict = torch.load(checkpoint, map_location="cpu", weights_only=False) + model.load_state_dict(state_dict["model"]) + model.eval() + _logger.info(f"VQVAE model loaded from '{checkpoint}'!") + return model, cfg + + +class VQVAE(nn.Module): + def __init__(self, h): + super().__init__() + + self.encoder = Encoder(**h.encoder_params) + self.vq = Bottleneck(**h.vq_params) + self.decoder = Decoder(**h.decoder_params) + + def forward(self, x, compute_pred=False): + with torch.no_grad(): + z = self.encoder(x) + codes, z_q, commit_losses, metrics = self.vq(z) + + if not compute_pred: + return codes + x_hat = self.decoder(z_q) + + return codes, x_hat + + +class BottleneckBlock(nn.Module): + def __init__(self, k_bins, emb_width, mu): + super().__init__() + self.k_bins = k_bins + self.emb_width = emb_width + self.mu = mu + self.reset_k() + self.threshold = 1.0 + + def reset_k(self): + self.init = False + self.k_sum = None + self.k_elem = None + self.register_buffer("k", torch.zeros(self.k_bins, self.emb_width)) + + def _tile(self, x): + d, ew = x.shape + if d < self.k_bins: + n_repeats = (self.k_bins + d - 1) // d + std = 0.01 / np.sqrt(ew) + x = x.repeat(n_repeats, 1) + x = x + torch.randn_like(x) * std + return x + + def init_k(self, x): + mu, emb_width, k_bins = self.mu, self.emb_width, self.k_bins + self.init = True + # init k_w using random vectors from x + y = self._tile(x) + _k_rand = y[torch.randperm(y.shape[0])][:k_bins] + dist.broadcast(_k_rand, 0) + self.k = _k_rand + assert self.k.shape == (k_bins, emb_width) + self.k_sum = self.k + self.k_elem = torch.ones(k_bins, device=self.k.device) + + def restore_k(self, num_tokens=None, threshold=1.0): + mu, emb_width, k_bins = self.mu, self.emb_width, self.k_bins + self.init = True + assert self.k.shape == (k_bins, emb_width) + self.k_sum = self.k.clone() + self.k_elem = torch.ones(k_bins, device=self.k.device) + if num_tokens is not None: + expected_usage = num_tokens / k_bins + self.k_elem.data.mul_(expected_usage) + self.k_sum.data.mul_(expected_usage) + self.threshold = threshold + + def update_k(self, x, x_l): + mu, emb_width, k_bins = self.mu, self.emb_width, self.k_bins + with torch.no_grad(): + # Calculate new centres + x_l_onehot = torch.zeros( + k_bins, x.shape[0], device=x.device + ) # k_bins, N * L + x_l_onehot.scatter_(0, x_l.view(1, x.shape[0]), 1) + + _k_sum = torch.matmul(x_l_onehot, x) # k_bins, w + _k_elem = x_l_onehot.sum(dim=-1) # k_bins + y = self._tile(x) + _k_rand = y[torch.randperm(y.shape[0])][:k_bins] + + dist.broadcast(_k_rand, 0) + dist.all_reduce(_k_sum) + dist.all_reduce(_k_elem) + + # Update centres + old_k = self.k + self.k_sum = mu * self.k_sum + (1.0 - mu) * _k_sum # w, k_bins + self.k_elem = mu * self.k_elem + (1.0 - mu) * _k_elem # k_bins + usage = (self.k_elem.view(k_bins, 1) >= self.threshold).float() + self.k = ( + usage + * (self.k_sum.view(k_bins, emb_width) / self.k_elem.view(k_bins, 1)) + + (1 - usage) * _k_rand + ) + _k_prob = 
_k_elem / torch.sum( + _k_elem + ) # x_l_onehot.mean(dim=-1) # prob of each bin + entropy = -torch.sum( + _k_prob * torch.log(_k_prob + 1e-8) + ) # entropy ie how diverse + used_curr = (_k_elem >= self.threshold).sum() + usage = torch.sum(usage) + dk = torch.norm(self.k - old_k) / np.sqrt(np.prod(old_k.shape)) + return dict(entropy=entropy, used_curr=used_curr, usage=usage, dk=dk) + + def preprocess(self, x): + # NCT -> NTC -> [NT, C] + x = x.permute(0, 2, 1).contiguous() + x = x.view(-1, x.shape[-1]) # x_en = (N * L, w), k_j = (w, k_bins) + + if x.shape[-1] == self.emb_width: + prenorm = torch.norm(x - torch.mean(x)) / np.sqrt(np.prod(x.shape)) + elif x.shape[-1] == 2 * self.emb_width: + x1, x2 = x[..., : self.emb_width], x[..., self.emb_width :] + prenorm = (torch.norm(x1 - torch.mean(x1)) / np.sqrt(np.prod(x1.shape))) + ( + torch.norm(x2 - torch.mean(x2)) / np.sqrt(np.prod(x2.shape)) + ) + + # Normalise + x = x1 + x2 + else: + assert False, f"Expected {x.shape[-1]} to be (1 or 2) * {self.emb_width}" + return x, prenorm + + def postprocess(self, x_l, x_d, x_shape): + # [NT, C] -> NTC -> NCT + N, T = x_shape + x_d = x_d.view(N, T, -1).permute(0, 2, 1).contiguous() + x_l = x_l.view(N, T) + return x_l, x_d + + def quantise(self, x): + # Calculate latent code x_l + k_w = self.k.t() + distance = ( + torch.sum(x**2, dim=-1, keepdim=True) + - 2 * torch.matmul(x, k_w) + + torch.sum(k_w**2, dim=0, keepdim=True) + ) # (N * L, b) + min_distance, x_l = torch.min(distance, dim=-1) + fit = torch.mean(min_distance) + return x_l, fit + + def dequantise(self, x_l): + x = F.embedding(x_l, self.k) + return x + + def encode(self, x): + N, width, T = x.shape + + # Preprocess. + x, prenorm = self.preprocess(x) + + # Quantise + x_l, fit = self.quantise(x) + + # Postprocess. 
+ x_l = x_l.view(N, T) + return x_l + + def decode(self, x_l): + N, T = x_l.shape + width = self.emb_width + + # Dequantise + x_d = self.dequantise(x_l) + + # Postprocess + x_d = x_d.view(N, T, width).permute(0, 2, 1).contiguous() + return x_d + + def forward(self, x, update_k=True): + N, width, T = x.shape + + # Preprocess + x, prenorm = self.preprocess(x) + + # Init k if not inited + if update_k and not self.init: + self.init_k(x) + + # Quantise and dequantise through bottleneck + x_l, fit = self.quantise(x) + x_d = self.dequantise(x_l) + + # Update embeddings + if update_k and self.training: + update_metrics = self.update_k(x, x_l) + else: + update_metrics = {} + + # Loss + commit_loss = torch.norm(x_d.detach() - x) ** 2 / np.prod(x.shape) + + # Passthrough + x_d = x + (x_d - x).detach() + + # Postprocess + x_l, x_d = self.postprocess(x_l, x_d, (N, T)) + return x_l, x_d, commit_loss, dict(fit=fit, pn=prenorm, **update_metrics) + + +class Bottleneck(nn.Module): + def __init__(self, l_bins, emb_width, mu, levels): + super().__init__() + self.levels = levels + level_block = lambda level: BottleneckBlock(l_bins, emb_width, mu) + self.level_blocks = nn.ModuleList() + for level in range(self.levels): + self.level_blocks.append(level_block(level)) + + def encode(self, xs): + zs = [level_block.encode(x) for (level_block, x) in zip(self.level_blocks, xs)] + return zs + + def decode(self, zs, start_level=0, end_level=None): + if end_level is None: + end_level = self.levels + xs_quantised = [ + level_block.decode(z) + for (level_block, z) in zip(self.level_blocks[start_level:end_level], zs) + ] + return xs_quantised + + def forward(self, xs): + zs, xs_quantised, commit_losses, metrics = [], [], [], [] + for level in range(self.levels): + level_block = self.level_blocks[level] + x = xs[level] + z, x_quantised, commit_loss, metric = level_block(x, update_k=self.training) + zs.append(z) + if not self.training: + # Be extra paranoid and make sure the encoder weights can't + # change from straight-through estimator + x_quantised = x_quantised.detach() + xs_quantised.append(x_quantised) + commit_losses.append(commit_loss) + if self.training: + metrics.append(metric) + return zs, xs_quantised, commit_losses, metrics + + +class ResConvBlock(nn.Module): + def __init__(self, n_in, n_state): + super().__init__() + self.model = nn.Sequential( + nn.ReLU(), + nn.Conv2d(n_in, n_state, 3, 1, 1), + nn.ReLU(), + nn.Conv2d(n_state, n_in, 1, 1, 0), + ) + + def forward(self, x): + return x + self.model(x) + + +class Resnet(nn.Module): + def __init__(self, n_in, n_depth, m_conv=1.0): + super().__init__() + self.model = nn.Sequential( + *[ResConvBlock(n_in, int(m_conv * n_in)) for _ in range(n_depth)] + ) + + def forward(self, x): + return self.model(x) + + +class ResConv1DBlock(nn.Module): + def __init__(self, n_in, n_state, dilation=1, zero_out=False, res_scale=1.0): + super().__init__() + padding = dilation + self.model = nn.Sequential( + nn.ReLU(), + nn.Conv1d(n_in, n_state, 3, 1, padding, dilation), + nn.ReLU(), + nn.Conv1d(n_state, n_in, 1, 1, 0), + ) + if zero_out: + out = self.model[-1] + nn.init.zeros_(out.weight) + nn.init.zeros_(out.bias) + self.res_scale = res_scale + + def forward(self, x): + return x + self.res_scale * self.model(x) + + +class Resnet1D(nn.Module): + def __init__( + self, + n_in, + n_depth, + m_conv=1.0, + dilation_growth_rate=1, + dilation_cycle=None, + zero_out=False, + res_scale=False, + reverse_dilation=False, + checkpoint_res=False, + ): + super().__init__() + + def _get_depth(depth): 
+ if dilation_cycle is None: + return depth + else: + return depth % dilation_cycle + + blocks = [ + ResConv1DBlock( + n_in, + int(m_conv * n_in), + dilation=dilation_growth_rate ** _get_depth(depth), + zero_out=zero_out, + res_scale=1.0 if not res_scale else 1.0 / math.sqrt(n_depth), + ) + for depth in range(n_depth) + ] + if reverse_dilation: + blocks = blocks[::-1] + self.checkpoint_res = checkpoint_res + if self.checkpoint_res == 1: + if dist.get_rank() == 0: + _logger.warning("Checkpointing convs") + self.blocks = nn.ModuleList(blocks) + else: + self.model = nn.Sequential(*blocks) + + def forward(self, x): + if self.checkpoint_res == 1: + raise NotImplementedError("Checkpoint not implemented") + else: + return self.model(x) + + +def assert_shape(x, exp_shape): + assert x.shape == exp_shape, f"Expected {exp_shape} got {x.shape}" + + +class EncoderConvBlock(nn.Module): + def __init__( + self, + input_emb_width, + output_emb_width, + down_t, + stride_t, + width, + depth, + m_conv, + dilation_growth_rate=1, + dilation_cycle=None, + zero_out=False, + res_scale=False, + ): + super().__init__() + blocks = [] + if type(stride_t) is tuple or type(stride_t) is list: + start = True + for s_t, d_t in zip(stride_t, down_t): + if s_t % 2 == 0: + filter_t, pad_t = s_t * 2, s_t // 2 + else: + filter_t, pad_t = s_t * 2 + 1, s_t // 2 + 1 + if d_t > 0: + for i in range(d_t): + block = nn.Sequential( + nn.Conv1d( + input_emb_width if i == 0 and start else width, + width, + filter_t, + s_t, + pad_t, + ), + Resnet1D( + width, + depth, + m_conv, + dilation_growth_rate, + dilation_cycle, + zero_out, + res_scale, + ), + ) + blocks.append(block) + start = False + block = nn.Conv1d(width, output_emb_width, 3, 1, 1) + blocks.append(block) + else: + filter_t, pad_t = stride_t * 2, stride_t // 2 + if down_t > 0: + for i in range(down_t): + block = nn.Sequential( + nn.Conv1d( + input_emb_width if i == 0 else width, + width, + filter_t, + stride_t, + pad_t, + ), + Resnet1D( + width, + depth, + m_conv, + dilation_growth_rate, + dilation_cycle, + zero_out, + res_scale, + ), + ) + blocks.append(block) + block = nn.Conv1d(width, output_emb_width, 3, 1, 1) + blocks.append(block) + self.model = nn.Sequential(*blocks) + + def forward(self, x): + return self.model(x) + + +class DecoderConvBock(nn.Module): + def __init__( + self, + input_emb_width, + output_emb_width, + down_t, + stride_t, + width, + depth, + m_conv, + dilation_growth_rate=1, + dilation_cycle=None, + zero_out=False, + res_scale=False, + reverse_decoder_dilation=False, + checkpoint_res=False, + ): + super().__init__() + blocks = [] + + if type(stride_t) is tuple or type(stride_t) is list: + block = nn.Conv1d(output_emb_width, width, 3, 1, 1) + blocks.append(block) + for k, (s_t, d_t) in enumerate(zip(stride_t, down_t)): + if d_t > 0: + if s_t % 2 == 0: + filter_t, pad_t = s_t * 2, s_t // 2 + else: + filter_t, pad_t = s_t * 2 + 1, s_t // 2 + 1 + end = k == len(stride_t) - 1 + for i in range(d_t): + block = nn.Sequential( + Resnet1D( + width, + depth, + m_conv, + dilation_growth_rate, + dilation_cycle, + zero_out=zero_out, + res_scale=res_scale, + reverse_dilation=reverse_decoder_dilation, + checkpoint_res=checkpoint_res, + ), + nn.ConvTranspose1d( + width, + input_emb_width if i == (d_t - 1) and end else width, + filter_t, + s_t, + pad_t, + ), + ) + blocks.append(block) + else: + if down_t > 0: + filter_t, pad_t = stride_t * 2, stride_t // 2 + block = nn.Conv1d(output_emb_width, width, 3, 1, 1) + blocks.append(block) + for i in range(down_t): + block = 
nn.Sequential( + Resnet1D( + width, + depth, + m_conv, + dilation_growth_rate, + dilation_cycle, + zero_out=zero_out, + res_scale=res_scale, + reverse_dilation=reverse_decoder_dilation, + checkpoint_res=checkpoint_res, + ), + nn.ConvTranspose1d( + width, + input_emb_width if i == (down_t - 1) else width, + filter_t, + stride_t, + pad_t, + ), + ) + blocks.append(block) + self.model = nn.Sequential(*blocks) + + def forward(self, x): + return self.model(x) + + +class Encoder(nn.Module): + def __init__( + self, + input_emb_width, + output_emb_width, + levels, + downs_t, + strides_t, + **block_kwargs, + ): + super().__init__() + self.input_emb_width = input_emb_width + self.output_emb_width = output_emb_width + self.levels = levels + self.downs_t = downs_t + self.strides_t = strides_t + + block_kwargs_copy = dict(**block_kwargs) + if "reverse_decoder_dilation" in block_kwargs_copy: + del block_kwargs_copy["reverse_decoder_dilation"] + level_block = lambda level, down_t, stride_t: EncoderConvBlock( + input_emb_width if level == 0 else output_emb_width, + output_emb_width, + down_t, + stride_t, + **block_kwargs_copy, + ) + self.level_blocks = nn.ModuleList() + iterator = zip(list(range(self.levels)), downs_t, strides_t) + for level, down_t, stride_t in iterator: + self.level_blocks.append(level_block(level, down_t, stride_t)) + + def forward(self, x): + N, T = x.shape[0], x.shape[-1] + emb = self.input_emb_width + assert_shape(x, (N, emb, T)) + xs = [] + + # 64, 32, ... + iterator = zip(list(range(self.levels)), self.downs_t, self.strides_t) + for level, down_t, stride_t in iterator: + level_block = self.level_blocks[level] + x = level_block(x) + if type(stride_t) is tuple or type(stride_t) is list: + emb, T = self.output_emb_width, T // np.prod( + [s**d for s, d in zip(stride_t, down_t)] + ) + else: + emb, T = self.output_emb_width, T // (stride_t**down_t) + assert_shape(x, (N, emb, T)) + xs.append(x) + + return xs + + +class Decoder(nn.Module): + def __init__( + self, + input_emb_width, + output_emb_width, + levels, + downs_t, + strides_t, + **block_kwargs, + ): + super().__init__() + self.input_emb_width = input_emb_width + self.output_emb_width = output_emb_width + self.levels = levels + + self.downs_t = downs_t + + self.strides_t = strides_t + + level_block = lambda level, down_t, stride_t: DecoderConvBock( + output_emb_width, output_emb_width, down_t, stride_t, **block_kwargs + ) + self.level_blocks = nn.ModuleList() + iterator = zip(list(range(self.levels)), downs_t, strides_t) + for level, down_t, stride_t in iterator: + self.level_blocks.append(level_block(level, down_t, stride_t)) + + self.out = nn.Conv1d(output_emb_width, input_emb_width, 3, 1, 1) + + def forward(self, xs, all_levels=True): + if all_levels: + assert len(xs) == self.levels + else: + assert len(xs) == 1 + x = xs[-1] + N, T = x.shape[0], x.shape[-1] + emb = self.output_emb_width + assert_shape(x, (N, emb, T)) + + # 32, 64 ... 
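+ # Decode from the coarsest level down to level 0: each level block upsamples T by + # stride_t ** down_t (or the product over the (stride, down) pairs), and when + # all_levels=True the result is summed with the next finer level's input before continuing.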
+ iterator = reversed( + list(zip(list(range(self.levels)), self.downs_t, self.strides_t)) + ) + for level, down_t, stride_t in iterator: + level_block = self.level_blocks[level] + x = level_block(x) + if type(stride_t) is tuple or type(stride_t) is list: + emb, T = self.output_emb_width, T * np.prod( + [s**d for s, d in zip(stride_t, down_t)] + ) + else: + emb, T = self.output_emb_width, T * (stride_t**down_t) + assert_shape(x, (N, emb, T)) + if level != 0 and all_levels: + x = x + xs[level - 1] + + x = self.out(x) + return x diff --git a/spiritlm/speech_tokenizer/hifigan/__init__.py b/spiritlm/speech_tokenizer/hifigan/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7f0dfa702478efc001d6ce4e768e0a2fca46bb9a --- /dev/null +++ b/spiritlm/speech_tokenizer/hifigan/__init__.py @@ -0,0 +1,42 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the FAIR Noncommercial Research License +# found in the LICENSE file in the root directory of this source tree. + +from pathlib import Path +import os + +import torch + +from .hifigan_vocoder import HifiGANVocoder + +# Get the base checkpoints directory from environment variable or use the default base path +base_checkpoints_dir = Path(os.getenv("SPIRITLM_CHECKPOINTS_DIR", Path(__file__).parents[3] / "checkpoints")) + +# Append 'speech_tokenizer' to the base path +CHECKPOINT_DIR = base_checkpoints_dir / "speech_tokenizer" + +CURRENT_DEVICE = ( + torch.device(torch.cuda.current_device()) + if torch.cuda.is_available() + else "mps" if torch.backends.mps.is_available() else "cpu" +) + + +def spiritlm_base_hifigan( + default_speaker=2, + default_style=8, # conv-default +): + return HifiGANVocoder( + CHECKPOINT_DIR / "hifigan_spiritlm_base/generator.pt", + default_speaker=default_speaker, + default_style=default_style, + ).to(CURRENT_DEVICE) + + +def spiritlm_expressive_hifigan_w2v2(default_speaker=2): + return HifiGANVocoder( + CHECKPOINT_DIR / "hifigan_spiritlm_expressive_w2v2/generator.pt", + default_speaker=default_speaker, + ).to(CURRENT_DEVICE) diff --git a/spiritlm/speech_tokenizer/hifigan/__pycache__/__init__.cpython-310.pyc b/spiritlm/speech_tokenizer/hifigan/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..14be7ee5828ded5d40bd66af7b3fd5c8dff4dd51 Binary files /dev/null and b/spiritlm/speech_tokenizer/hifigan/__pycache__/__init__.cpython-310.pyc differ diff --git a/spiritlm/speech_tokenizer/hifigan/__pycache__/hifigan_vocoder.cpython-310.pyc b/spiritlm/speech_tokenizer/hifigan/__pycache__/hifigan_vocoder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..97aa3e216dc8a8807235f27e555cf2fd39be4630 Binary files /dev/null and b/spiritlm/speech_tokenizer/hifigan/__pycache__/hifigan_vocoder.cpython-310.pyc differ diff --git a/spiritlm/speech_tokenizer/hifigan/hifigan_vocoder.py b/spiritlm/speech_tokenizer/hifigan/hifigan_vocoder.py new file mode 100644 index 0000000000000000000000000000000000000000..e97e66958d6208b47c867391bf3772786dab51e3 --- /dev/null +++ b/spiritlm/speech_tokenizer/hifigan/hifigan_vocoder.py @@ -0,0 +1,546 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the FAIR Noncommercial Research License +# found in the LICENSE file in the root directory of this source tree. 
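+ +# Example usage (sketch; assumes the HiFi-GAN checkpoint is available under +# checkpoints/speech_tokenizer/hifigan_spiritlm_base/): +# from spiritlm.speech_tokenizer.hifigan import spiritlm_base_hifigan +# vocoder = spiritlm_base_hifigan() +# wav = vocoder("71 12 57 57 12") # space-separated HuBERT units -> waveform tensor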
+ +# Standalone Hifigan vocoder +# Adapted from: +# - https://github.com/jik876/hifi-gan +# - https://github.com/facebookresearch/fairseq/tree/main/fairseq/models/text_to_speech +# - https://github.com/facebookresearch/speech-resynthesis/blob/main/examples/speech_to_speech_translation/models.py +# - https://github.com/facebookresearch/speech-resynthesis/blob/main/examples/expresso/models.py + +import json +import logging +from pathlib import Path +from typing import Dict + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import Conv1d, ConvTranspose1d +from torch.nn.utils import remove_weight_norm, weight_norm + +_logger = logging.getLogger(__name__) + + +class HifiGANVocoder(nn.Module): + def __init__( + self, + checkpoint_path, + config_path=None, + default_speaker=0, + default_style=0, + fp16=False, + ): + super().__init__() + + if config_path is None: + config_path = Path(checkpoint_path).parent / "config.json" + with open(config_path) as f: + cfg = json.load(f) + self.vocoder = CodeHiFiGANVocoderModel(checkpoint_path, cfg, fp16) + self.vocoder.eval() + + self.multispkr = self.vocoder.model.multispkr + if self.multispkr: + self.default_speaker = default_speaker + speakers_path = Path(checkpoint_path).parent / "speakers.txt" + if speakers_path.exists(): + with open(speakers_path) as f: + self.speakers = [line.strip() for line in f] + _logger.info( + f"Loaded {len(self.speakers)} speakers. First few speakers: {self.speakers[:10]}" + ) + + self.multistyle = self.vocoder.model.multistyle + if self.multistyle: + self.default_style = default_style + styles_path = Path(checkpoint_path).parent / "styles.txt" + if styles_path.exists(): + with open(styles_path) as f: + self.styles = [line.strip() for line in f] + _logger.info( + f"Loaded {len(self.styles)} styles. 
First few styles: {self.styles[:10]}" + ) + + self.dur_pred = self.vocoder.model.dur_predictor is not None + self.cfg = cfg + + _logger.info( + f"HifiGAN: Duration Prediction = {self.dur_pred} - " + f"Multiple Speaker = {bool(self.multispkr)} - " + f"Multiple Style = {bool(self.multistyle)}" + ) + + # this is useful for determining the device + self.register_buffer("_float_tensor", torch.tensor([0], dtype=torch.float)) + + @property + def device(self): + return self._float_tensor.device + + def preprocess_code(self, code, deduplicate_code=False): + if isinstance(code, str): + code = code.split() + if isinstance(code, list): + code = list(map(int, code)) + code = torch.tensor(code) + elif isinstance(code, np.ndarray): + code = torch.from_numpy(code) + code = code.long() + if deduplicate_code: + code = torch.unique_consecutive(code) + return code.view(1, -1) + + def forward( + self, + code, + speaker_id=None, + style_id=None, + dur_pred=True, + f0_code=None, + style_code=None, + not_dedup_code=False, + ): + assert not ( + dur_pred and not self.dur_pred + ), "Model doesnt't support duration prediction" + inp = dict() + inp["code"] = self.preprocess_code(code, dur_pred and not not_dedup_code) + if f0_code is not None: + inp["f0_code"] = self.preprocess_code(f0_code, deduplicate_code=False) + if style_code is not None: + inp["style_code"] = self.preprocess_code(style_code, deduplicate_code=False) + if self.multispkr: + if speaker_id is None: + speaker_id = self.default_speaker + inp["spkr"] = torch.LongTensor([speaker_id]).view(1, 1) + if self.multistyle: + if style_id is None: + style_id = self.default_style + inp["style"] = torch.LongTensor([style_id]).view(1, 1) + inp = {k: v.to(self.device) for k, v in inp.items()} + return self.vocoder(inp, dur_pred) + + +class CodeHiFiGANVocoderModel(nn.Module): + def __init__( + self, checkpoint_path: str, model_cfg: Dict[str, str], fp16: bool = False + ) -> None: + super().__init__() + self.model = CodeGenerator(model_cfg) + state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True) + self.model.load_state_dict(state_dict["generator"]) + self.model.eval() + if fp16: + self.model.half() + self.model.remove_weight_norm() + _logger.info(f"Loaded CodeHiFiGAN checkpoint from '{checkpoint_path}'") + + def upsample(self, code, downsampled_code, uprate): + N = code.size(1) + K = downsampled_code.size(1) + assert abs(K * uprate - N) / uprate <= 1, (N, K, uprate) + upsampled_code = torch.repeat_interleave(downsampled_code, uprate, dim=1) + if upsampled_code.size(1) < N: + z = torch.zeros_like(code) + z[:, : upsampled_code.size(1)] = upsampled_code + z[:, upsampled_code.size(1) :] = upsampled_code[:, -1].view(-1, 1) + upsampled_code = z + upsampled_code = upsampled_code[:, :N] + return upsampled_code + + def forward(self, x: Dict[str, torch.Tensor], dur_prediction=False) -> torch.Tensor: + assert "code" in x + x["dur_prediction"] = dur_prediction + + # remove invalid code + mask = x["code"] >= 0 + x["code"] = x["code"][mask].unsqueeze(dim=0) + if "f0" in x: + f0_up_ratio = x["f0"].size(1) // x["code"].size(1) + mask = mask.unsqueeze(2).repeat(1, 1, f0_up_ratio).view(-1, x["f0"].size(1)) + x["f0"] = x["f0"][mask].unsqueeze(dim=0) + + # preprocess f0 & style codes + if "f0_code" in x: + if dur_prediction: # f0 must be upsampled first if dedup + assert len(x["f0_code"][0]) == len( + x["code"][0] + ), f"f0 must be upsampled first if dedup (f0_code length: {len(x['f0_code'][0])}, code length: {len(x['code'][0])})" + else: + x["f0_code"] = 
self.upsample( + x["code"], x["f0_code"], self.model.hubert_to_f0 + ) + + if "style_code" in x: + if dur_prediction: # style must be upsampled first if dedup + assert len(x["style_code"][0]) == len( + x["code"][0] + ), f"style must be upsampled first if dedup (style_code length: {len(x['style_code'][0])}, code length: {len(x['code'][0])})" + else: + x["style_code"] = self.upsample( + x["code"], x["style_code"], self.model.hubert_to_style + ) + + return self.model(**x).detach().squeeze() + + +# HiFi-GAN Generator +LRELU_SLOPE = 0.1 + + +def init_weights(m, mean=0.0, std=0.01): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(mean, std) + + +def get_padding(kernel_size, dilation=1): + return (kernel_size * dilation - dilation) // 2 + + +class ResBlock(nn.Module): + def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): + super(ResBlock, self).__init__() + self.convs1 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]), + ) + ), + ] + ) + self.convs1.apply(init_weights) + + self.convs2 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + ] + ) + self.convs2.apply(init_weights) + + def forward(self, x): + for c1, c2 in zip(self.convs1, self.convs2): + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c1(xt) + xt = F.leaky_relu(xt, LRELU_SLOPE) + xt = c2(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for layer in self.convs1: + remove_weight_norm(layer) + for layer in self.convs2: + remove_weight_norm(layer) + + +class Generator(nn.Module): + def __init__(self, cfg): + super(Generator, self).__init__() + self.num_kernels = len(cfg["resblock_kernel_sizes"]) + self.num_upsamples = len(cfg["upsample_rates"]) + self.conv_pre = weight_norm( + Conv1d( + cfg.get("model_in_dim", 80), + cfg["upsample_initial_channel"], + 7, + 1, + padding=3, + ) + ) + + self.ups = nn.ModuleList() + for i, (u, k) in enumerate( + zip(cfg["upsample_rates"], cfg["upsample_kernel_sizes"]) + ): + self.ups.append( + weight_norm( + ConvTranspose1d( + cfg["upsample_initial_channel"] // (2**i), + cfg["upsample_initial_channel"] // (2 ** (i + 1)), + k, + u, + padding=(k - u) // 2, + ) + ) + ) + + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = cfg["upsample_initial_channel"] // (2 ** (i + 1)) + for k, d in zip( + cfg["resblock_kernel_sizes"], cfg["resblock_dilation_sizes"] + ): + self.resblocks.append(ResBlock(ch, k, d)) + + self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3)) + self.ups.apply(init_weights) + self.conv_post.apply(init_weights) + + def forward(self, x): + x = self.conv_pre(x) + for i in range(self.num_upsamples): + x = F.leaky_relu(x, LRELU_SLOPE) + x = self.ups[i](x) + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x) + else: + xs += self.resblocks[i * 
self.num_kernels + j](x) + x = xs / self.num_kernels + x = F.leaky_relu(x) + x = self.conv_post(x) + x = torch.tanh(x) + + return x + + def remove_weight_norm(self): + _logger.info("Removing weight norm...") + for layer in self.ups: + remove_weight_norm(layer) + for layer in self.resblocks: + layer.remove_weight_norm() + remove_weight_norm(self.conv_pre) + remove_weight_norm(self.conv_post) + + +class VariancePredictor(nn.Module): + def __init__( + self, + encoder_embed_dim, + var_pred_hidden_dim, + var_pred_kernel_size, + var_pred_dropout, + ): + super().__init__() + self.conv1 = nn.Sequential( + nn.Conv1d( + encoder_embed_dim, + var_pred_hidden_dim, + kernel_size=var_pred_kernel_size, + padding=(var_pred_kernel_size - 1) // 2, + ), + nn.ReLU(), + ) + self.ln1 = nn.LayerNorm(var_pred_hidden_dim) + self.dropout = var_pred_dropout + self.conv2 = nn.Sequential( + nn.Conv1d( + var_pred_hidden_dim, + var_pred_hidden_dim, + kernel_size=var_pred_kernel_size, + padding=1, + ), + nn.ReLU(), + ) + self.ln2 = nn.LayerNorm(var_pred_hidden_dim) + self.proj = nn.Linear(var_pred_hidden_dim, 1) + + def forward(self, x): + # Input: B x T x C; Output: B x T + x = self.conv1(x.transpose(1, 2)).transpose(1, 2) + x = F.dropout(self.ln1(x), p=self.dropout, training=self.training) + x = self.conv2(x.transpose(1, 2)).transpose(1, 2) + x = F.dropout(self.ln2(x), p=self.dropout, training=self.training) + return self.proj(x).squeeze(dim=2) + + +class CodeGenerator(Generator): + def __init__(self, cfg): + super().__init__(cfg) + self.dict = nn.Embedding(cfg["num_embeddings"], cfg["embedding_dim"]) + self.multispkr = cfg.get("multispkr", None) + self.embedder = cfg.get("embedder_params", None) + + self.f0_dict = None + if cfg.get("num_f0_tokens", None): + self.f0_dict = nn.Embedding(cfg["num_f0_tokens"], cfg["embedding_dim"]) + self.hubert_to_f0 = round( + cfg["f0_hop_size"] / cfg["code_hop_size"] + ) # 4 for 25hz hubert and 6.25hz f0 + + self.style_dict = None + if cfg.get("num_style_tokens", None): + self.style_dict = nn.Embedding( + cfg["num_style_tokens"], cfg["embedding_dim"] + ) + self.hubert_to_style = round( + cfg["style_hop_size"] / cfg["code_hop_size"] + ) # 25 for 25hz hubert and 1hz style + + self.multistyle = cfg.get("multistyle", None) + + if self.multispkr and not self.embedder: + self.spkr = nn.Embedding(cfg.get("num_speakers", 200), cfg["embedding_dim"]) + elif self.embedder: + self.spkr = nn.Linear(cfg.get("embedder_dim", 256), cfg["embedding_dim"]) + + if self.multistyle: + self.style = nn.Embedding(cfg.get("num_styles", 100), cfg["embedding_dim"]) + + self.dur_predictor = None + if cfg.get("dur_predictor_params", None): + self.dur_predictor = VariancePredictor( + cfg["dur_predictor_params"]["encoder_embed_dim"], + cfg["dur_predictor_params"]["var_pred_hidden_dim"], + cfg["dur_predictor_params"]["var_pred_kernel_size"], + cfg["dur_predictor_params"]["var_pred_dropout"], + ) + + self.f0 = cfg.get("f0", None) + n_f0_bin = cfg.get("f0_quant_num_bin", 0) + self.f0_quant_embed = ( + None if n_f0_bin <= 0 else nn.Embedding(n_f0_bin, cfg["embedding_dim"]) + ) + + @staticmethod + def _upsample(signal, max_frames): + if signal.dim() == 3: + bsz, channels, cond_length = signal.size() + elif signal.dim() == 2: + signal = signal.unsqueeze(2) + bsz, channels, cond_length = signal.size() + else: + signal = signal.view(-1, 1, 1) + bsz, channels, cond_length = signal.size() + + signal = signal.unsqueeze(3).repeat(1, 1, 1, max_frames // cond_length) + + # pad zeros as needed (if signal's shape does not divide 
completely with max_frames) + reminder = (max_frames - signal.shape[2] * signal.shape[3]) // signal.shape[3] + if reminder > 0: + raise NotImplementedError( + "Padding condition signal - misalignment between condition features." + ) + + signal = signal.view(bsz, channels, max_frames) + return signal + + def forward(self, **kwargs): + x = self.dict(kwargs["code"]).transpose(1, 2) + + dur_out = None + if self.dur_predictor and kwargs.get("dur_prediction", False): + assert x.size(0) == 1, "only support single sample" + log_dur_pred = self.dur_predictor(x.transpose(1, 2)) + dur_out = torch.clamp( + torch.round((torch.exp(log_dur_pred) - 1)).long(), min=1 + ) + # B x C x T + x = torch.repeat_interleave(x, dur_out.view(-1), dim=2) + + if self.f0: + if self.f0_quant_embed: + kwargs["f0"] = self.f0_quant_embed(kwargs["f0"].long()).transpose(1, 2) + else: + kwargs["f0"] = kwargs["f0"].unsqueeze(1) + + if x.shape[-1] < kwargs["f0"].shape[-1]: + x = self._upsample(x, kwargs["f0"].shape[-1]) + elif x.shape[-1] > kwargs["f0"].shape[-1]: + kwargs["f0"] = self._upsample(kwargs["f0"], x.shape[-1]) + x = torch.cat([x, kwargs["f0"]], dim=1) + + if self.f0_dict is not None: + f0 = self.f0_dict(kwargs["f0_code"]).transpose(1, 2) # B, C, T + if dur_out is not None: + f0 = torch.repeat_interleave(f0, dur_out.view(-1), dim=2) + x = torch.cat([x, f0], dim=1) # B, 2C, T + + if self.style_dict is not None: + style = self.style_dict(kwargs["style_code"]).transpose(1, 2) # B, C, T + if dur_out is not None: + style = torch.repeat_interleave(style, dur_out.view(-1), dim=2) + x = torch.cat([x, style], dim=1) # B, 2C, T + + if self.multispkr: + assert ( + "spkr" in kwargs + ), 'require "spkr" input for multispeaker CodeHiFiGAN vocoder' + spkr = self.spkr(kwargs["spkr"]).transpose(1, 2) + spkr = self._upsample(spkr, x.shape[-1]) + x = torch.cat([x, spkr], dim=1) + + if self.multistyle: + assert ( + "style" in kwargs + ), 'require "style" input for multistyle CodeHiFiGAN vocoder' + style = self.style(kwargs["style"]).transpose(1, 2) + style = self._upsample(style, x.shape[-1]) + x = torch.cat([x, style], dim=1) + + for k, feat in kwargs.items(): + if k in [ + "spkr", + "code", + "f0", + "dur_prediction", + "style", + "f0_code", + "style_code", + ]: + continue + + feat = self._upsample(feat, x.shape[-1]) + x = torch.cat([x, feat], dim=1) + + return super().forward(x) diff --git a/spiritlm/speech_tokenizer/hubert/__init__.py b/spiritlm/speech_tokenizer/hubert/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6b2faf21fc87a398ddb221b2f25591e0daec77c7 --- /dev/null +++ b/spiritlm/speech_tokenizer/hubert/__init__.py @@ -0,0 +1,33 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the FAIR Noncommercial Research License +# found in the LICENSE file in the root directory of this source tree. 
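+ +# HuBERT speech tokenizer factory: spiritlm_hubert() loads the 25 Hz mHuBERT checkpoint +# together with the layer-11 linear quantizer (L11_quantizer_500) and moves the tokenizer +# to CUDA, MPS or CPU depending on availability.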
+ +from pathlib import Path +import os + +import torch + +from .hubert_tokenizer import HubertTokenizer + +# Get the base checkpoints directory from environment variable or use the default base path +base_checkpoints_dir = Path(os.getenv("SPIRITLM_CHECKPOINTS_DIR", Path(__file__).parents[3] / "checkpoints")) + +# Append 'speech_tokenizer' to the base path +CHECKPOINT_DIR = base_checkpoints_dir / "speech_tokenizer" + +CURRENT_DEVICE = ( + torch.device(torch.cuda.current_device()) + if torch.cuda.is_available() + else "mps" if torch.backends.mps.is_available() else "cpu" +) + + +def spiritlm_hubert(): + return HubertTokenizer( + hubert_ckpt=CHECKPOINT_DIR / "hubert_25hz/mhubert_base_25hz.pt", + hubert_layer=11, + quantizer_ckpt=CHECKPOINT_DIR / "hubert_25hz/L11_quantizer_500.pt", + is_linear_quantizer=True, + ).to(CURRENT_DEVICE) diff --git a/spiritlm/speech_tokenizer/hubert/__pycache__/__init__.cpython-310.pyc b/spiritlm/speech_tokenizer/hubert/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b423e826a2c641bfcb70aceac647403fa071bfc3 Binary files /dev/null and b/spiritlm/speech_tokenizer/hubert/__pycache__/__init__.cpython-310.pyc differ diff --git a/spiritlm/speech_tokenizer/hubert/__pycache__/hubert_tokenizer.cpython-310.pyc b/spiritlm/speech_tokenizer/hubert/__pycache__/hubert_tokenizer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2c61d53409f42128cdd2f33858944cc2c0ead85c Binary files /dev/null and b/spiritlm/speech_tokenizer/hubert/__pycache__/hubert_tokenizer.cpython-310.pyc differ diff --git a/spiritlm/speech_tokenizer/hubert/__pycache__/quantizer_model.cpython-310.pyc b/spiritlm/speech_tokenizer/hubert/__pycache__/quantizer_model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ffcd046551136046b0790b8c01020da701add447 Binary files /dev/null and b/spiritlm/speech_tokenizer/hubert/__pycache__/quantizer_model.cpython-310.pyc differ diff --git a/spiritlm/speech_tokenizer/hubert/hubert_model/__init__.py b/spiritlm/speech_tokenizer/hubert/hubert_model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6f8b770b11f3e4f9294d89a9e0626e4eb6e2af8f --- /dev/null +++ b/spiritlm/speech_tokenizer/hubert/hubert_model/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the FAIR Noncommercial Research License +# found in the LICENSE file in the root directory of this source tree. 
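+ +# Re-exports the standalone HuBERT implementation (HubertModel, HubertConfig, +# HubertPretrainingConfig, load_hubert_model) adapted from s3prl/fairseq.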
+ +from .hubert_model import * diff --git a/spiritlm/speech_tokenizer/hubert/hubert_model/__pycache__/__init__.cpython-310.pyc b/spiritlm/speech_tokenizer/hubert/hubert_model/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..60a88e14db0e6be916ce8120c07321ce838be796 Binary files /dev/null and b/spiritlm/speech_tokenizer/hubert/hubert_model/__pycache__/__init__.cpython-310.pyc differ diff --git a/spiritlm/speech_tokenizer/hubert/hubert_model/__pycache__/hubert_model.cpython-310.pyc b/spiritlm/speech_tokenizer/hubert/hubert_model/__pycache__/hubert_model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..73c03896dd669a815896a48751912c62d39ba314 Binary files /dev/null and b/spiritlm/speech_tokenizer/hubert/hubert_model/__pycache__/hubert_model.cpython-310.pyc differ diff --git a/spiritlm/speech_tokenizer/hubert/hubert_model/__pycache__/wav2vec2_model.cpython-310.pyc b/spiritlm/speech_tokenizer/hubert/hubert_model/__pycache__/wav2vec2_model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f2d144dec8197fb3aa90335a9f1fb30abbf72dd0 Binary files /dev/null and b/spiritlm/speech_tokenizer/hubert/hubert_model/__pycache__/wav2vec2_model.cpython-310.pyc differ diff --git a/spiritlm/speech_tokenizer/hubert/hubert_model/hubert_model.py b/spiritlm/speech_tokenizer/hubert/hubert_model/hubert_model.py new file mode 100644 index 0000000000000000000000000000000000000000..0542f92af8a3f6f352a753ba797b44036b4e7624 --- /dev/null +++ b/spiritlm/speech_tokenizer/hubert/hubert_model/hubert_model.py @@ -0,0 +1,685 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the FAIR Noncommercial Research License +# found in the LICENSE file in the root directory of this source tree. + +# This file was copied and adapted from the s3prl package: +# - https://github.com/s3prl/s3prl/blob/main/s3prl/upstream/hubert/hubert_model.py +# which was adapted from fairseq to remove the dependency on the entire fairseq package + + +import logging +from dataclasses import dataclass, field, is_dataclass +from typing import Any, Dict, List, Optional, Tuple + +import numpy as np +import torch +import torch.nn as nn + +from .wav2vec2_model import ( + EXTRACTOR_MODE_CHOICES, + LAYER_TYPE_CHOICES, + MASKING_DISTRIBUTION_CHOICES, + ChoiceEnum, + ConvFeatureExtractionModel, + GradMultiply, + LayerNorm, + TransformerEncoder, + compute_mask_indices, + get_available_activation_fns, +) + +_logger = logging.getLogger(__name__) + + +@dataclass +class HubertPretrainingConfig: + label_rate: float = field( + default=-1.0, + metadata={"help": "label frame rate. -1.0 for sequence label"}, + ) + sample_rate: int = field( + default=16_000, + metadata={ + "help": "target sample rate. 
audio files will be up/down " + "sampled to this rate" + }, + ) + normalize: bool = field( + default=False, + metadata={"help": "if set, normalizes input to have 0 mean and unit variance"}, + ) + enable_padding: bool = field( + default=False, + metadata={"help": "pad shorter samples instead of cropping"}, + ) + max_keep_size: Optional[int] = field( + default=None, + metadata={"help": "exclude sample longer than this"}, + ) + max_sample_size: Optional[int] = field( + default=None, + metadata={"help": "max sample size to crop to for batching"}, + ) + min_sample_size: Optional[int] = field( + default=None, + metadata={"help": "min sample size to crop to for batching"}, + ) + random_crop: Optional[bool] = field( + default=True, + metadata={"help": "always crop from the beginning if false"}, + ) + pad_audio: Optional[bool] = field( + default=False, + metadata={"help": "pad audio to the longest one in the batch if true"}, + ) + + +@dataclass +class HubertConfig: + label_rate: float + + extractor_mode: EXTRACTOR_MODE_CHOICES = field( + default="default", + metadata={ + "help": "mode for feature extractor. default has a single group " + "norm with d groups in the first conv block, whereas layer_norm " + "has layer norms in every block (meant to use with normalize=True)" + }, + ) + encoder_layers: int = field( + default=12, metadata={"help": "num encoder layers in the transformer"} + ) + encoder_embed_dim: int = field( + default=768, metadata={"help": "encoder embedding dimension"} + ) + encoder_ffn_embed_dim: int = field( + default=3072, metadata={"help": "encoder embedding dimension for FFN"} + ) + encoder_attention_heads: int = field( + default=12, metadata={"help": "num encoder attention heads"} + ) + activation_fn: ChoiceEnum(get_available_activation_fns()) = field( + default="gelu", metadata={"help": "activation function to use"} + ) + layer_type: LAYER_TYPE_CHOICES = field( + default="transformer", metadata={"help": "layer type in encoder"} + ) + + # dropouts + dropout: float = field( + default=0.1, + metadata={"help": "dropout probability for the transformer"}, + ) + attention_dropout: float = field( + default=0.1, + metadata={"help": "dropout probability for attention weights"}, + ) + activation_dropout: float = field( + default=0.0, + metadata={"help": "dropout probability after activation in FFN"}, + ) + encoder_layerdrop: float = field( + default=0.0, + metadata={"help": "probability of dropping a tarnsformer layer"}, + ) + dropout_input: float = field( + default=0.0, + metadata={"help": "dropout to apply to the input (after feat extr)"}, + ) + dropout_features: float = field( + default=0.0, + metadata={"help": "dropout to apply to the features (after feat extr)"}, + ) + + final_dim: int = field( + default=0, + metadata={ + "help": "project final representations and targets to this many " + "dimensions. 
set to encoder_embed_dim is <= 0" + }, + ) + untie_final_proj: bool = field( + default=False, + metadata={"help": "use separate projection for each target"}, + ) + layer_norm_first: bool = field( + default=False, + metadata={"help": "apply layernorm first in the transformer"}, + ) + conv_feature_layers: str = field( + default="[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2", + metadata={ + "help": "string describing convolutional feature extraction " + "layers in form of a python list that contains " + "[(dim, kernel_size, stride), ...]" + }, + ) + conv_bias: bool = field( + default=False, metadata={"help": "include bias in conv encoder"} + ) + logit_temp: float = field( + default=0.1, metadata={"help": "temperature to divide logits by"} + ) + target_glu: bool = field( + default=False, metadata={"help": "adds projection + glu to targets"} + ) + feature_grad_mult: float = field( + default=1.0, + metadata={"help": "multiply feature extractor var grads by this"}, + ) + + # masking + mask_length: int = field(default=10, metadata={"help": "mask length"}) + mask_prob: float = field( + default=0.65, + metadata={"help": "probability of replacing a token with mask"}, + ) + mask_selection: MASKING_DISTRIBUTION_CHOICES = field( + default="static", metadata={"help": "how to choose mask length"} + ) + mask_other: float = field( + default=0, + metadata={ + "help": "secondary mask argument " + "(used for more complex distributions), " + "see help in compute_mask_indicesh" + }, + ) + no_mask_overlap: bool = field( + default=False, metadata={"help": "whether to allow masks to overlap"} + ) + mask_min_space: int = field( + default=1, + metadata={"help": "min space between spans (if no overlap is enabled)"}, + ) + + # channel masking + mask_channel_length: int = field( + default=10, + metadata={"help": "length of the mask for features (channels)"}, + ) + mask_channel_prob: float = field( + default=0.0, + metadata={"help": "probability of replacing a feature with 0"}, + ) + mask_channel_selection: MASKING_DISTRIBUTION_CHOICES = field( + default="static", + metadata={"help": "how to choose mask length for channel masking"}, + ) + mask_channel_other: float = field( + default=0, + metadata={ + "help": "secondary mask argument " + "(used for more complex distributions), " + "see help in compute_mask_indicesh" + }, + ) + no_mask_channel_overlap: bool = field( + default=False, + metadata={"help": "whether to allow channel masks to overlap"}, + ) + mask_channel_min_space: int = field( + default=1, + metadata={"help": "min space between spans (if no overlap is enabled)"}, + ) + + # positional embeddings + conv_pos: int = field( + default=128, + metadata={"help": "number of filters for convolutional positional embeddings"}, + ) + conv_pos_groups: int = field( + default=16, + metadata={"help": "number of groups for convolutional positional embedding"}, + ) + conv_pos_batch_norm: bool = field( + default=False, + metadata={ + "help": "use batch norm instead of weight norm in conv_pos (for bf16 models)" + }, + ) + + latent_temp: Tuple[float, float, float] = field( + default=(2, 0.5, 0.999995), + metadata={"help": "legacy (to be removed)"}, + ) + + # loss computation + skip_masked: bool = field( + default=False, + metadata={"help": "skip computing losses over masked frames"}, + ) + skip_nomask: bool = field( + default=False, + metadata={"help": "skip computing losses over unmasked frames"}, + ) + + checkpoint_activations: bool = field( + default=False, + metadata={"help": "recompute activations and save memory for extra 
compute"}, + ) + + # FP16 optimization + required_seq_len_multiple: int = field( + default=2, + metadata={ + "help": "pad the input to encoder such that the sequence length is divisible by multiple" + }, + ) + + # Conformer + depthwise_conv_kernel_size: int = field( + default=31, + metadata={ + "help": "depthwise-conv-kernel-size for convolution in conformer layer" + }, + ) + attn_type: str = field( + default="", + metadata={"help": "if espnet use ESPNET MHA"}, + ) + pos_enc_type: str = field( + default="abs", + metadata={"help": "Positional encoding type to use in conformer"}, + ) + fp16: bool = field(default=False, metadata={"help": "If fp16 is being used"}) + + +class HubertModel(torch.nn.Module): + def __init__( + self, + cfg: HubertConfig, + task_cfg: HubertPretrainingConfig, + dictionaries: Optional[List[Any]] = None, + ) -> None: + super().__init__() + _logger.info(f"HubertModel Config: {cfg}") + + feature_enc_layers = eval(cfg.conv_feature_layers) # noqa + self.embed = feature_enc_layers[-1][0] + + self.feature_extractor = ConvFeatureExtractionModel( + conv_layers=feature_enc_layers, + dropout=0.0, + mode=cfg.extractor_mode, + conv_bias=cfg.conv_bias, + ) + feature_ds_rate = np.prod([s for _, _, s in feature_enc_layers]) + self.feat2tar_ratio = cfg.label_rate * feature_ds_rate / task_cfg.sample_rate + + self.post_extract_proj = ( + nn.Linear(self.embed, cfg.encoder_embed_dim) + if self.embed != cfg.encoder_embed_dim + else None + ) + + self.mask_prob = cfg.mask_prob + self.mask_selection = cfg.mask_selection + self.mask_other = cfg.mask_other + self.mask_length = cfg.mask_length + self.no_mask_overlap = cfg.no_mask_overlap + self.mask_min_space = cfg.mask_min_space + + self.mask_channel_prob = cfg.mask_channel_prob + self.mask_channel_selection = cfg.mask_channel_selection + self.mask_channel_other = cfg.mask_channel_other + self.mask_channel_length = cfg.mask_channel_length + self.no_mask_channel_overlap = cfg.no_mask_channel_overlap + self.mask_channel_min_space = cfg.mask_channel_min_space + + self.dropout_input = nn.Dropout(cfg.dropout_input) + self.dropout_features = nn.Dropout(cfg.dropout_features) + + self.feature_grad_mult = cfg.feature_grad_mult + self.logit_temp = cfg.logit_temp + self.skip_masked = cfg.skip_masked + self.skip_nomask = cfg.skip_nomask + + final_dim = cfg.final_dim if cfg.final_dim > 0 else cfg.encoder_embed_dim + + self.mask_emb = nn.Parameter( + torch.FloatTensor(cfg.encoder_embed_dim).uniform_() + ) + + self.encoder = TransformerEncoder(cfg) + self.layer_norm = LayerNorm(self.embed) + + self.target_glu = None + if cfg.target_glu: + self.target_glu = nn.Sequential( + nn.Linear(final_dim, final_dim * 2), nn.GLU() + ) + + self.untie_final_proj = cfg.untie_final_proj + if dictionaries is None or len(dictionaries) == 0: + self.final_proj = None + self.label_embs_concat = None + _logger.info( + "cannot find any dictionary. assume will be used for inference." + ) + else: + if self.untie_final_proj: + self.final_proj = nn.Linear( + cfg.encoder_embed_dim, final_dim * len(dictionaries) + ) + else: + self.final_proj = nn.Linear(cfg.encoder_embed_dim, final_dim) + + # modules below are not needed during fine-tuning + if any([d is None for d in dictionaries]): + _logger.info( + "cannot find dictionary. 
assume will be used for fine-tuning" + ) + else: + self.num_classes = [len(d) for d in dictionaries] + self.label_embs_concat = nn.Parameter( + torch.FloatTensor(sum(self.num_classes), final_dim) + ) + nn.init.uniform_(self.label_embs_concat) + + def upgrade_state_dict_named(self, state_dict, name): + """Upgrade a (possibly old) state dict for new versions of fairseq.""" + + super().upgrade_state_dict_named(state_dict, name) + return state_dict + + @classmethod + def build_model(cls, cfg: HubertConfig, task): + """Build a new model instance.""" + + model = HubertModel(cfg, task.cfg, task.dictionaries) + return model + + def apply_mask(self, x, padding_mask, target_list): + B, T, C = x.shape + if self.mask_prob > 0: + mask_indices = compute_mask_indices( + (B, T), + padding_mask, + self.mask_prob, + self.mask_length, + self.mask_selection, + self.mask_other, + min_masks=2, + no_overlap=self.no_mask_overlap, + min_space=self.mask_min_space, + ) + mask_indices = torch.from_numpy(mask_indices).to(x.device) + x[mask_indices] = self.mask_emb + else: + mask_indices = None + + if self.mask_channel_prob > 0: + mask_channel_indices = compute_mask_indices( + (B, C), + None, + self.mask_channel_prob, + self.mask_channel_length, + self.mask_channel_selection, + self.mask_channel_other, + no_overlap=self.no_mask_channel_overlap, + min_space=self.mask_channel_min_space, + ) + mask_channel_indices = ( + torch.from_numpy(mask_channel_indices) + .to(x.device) + .unsqueeze(1) + .expand(-1, T, -1) + ) + x[mask_channel_indices] = 0 + + return x, mask_indices + + def compute_nce(self, x, pos, negs): + neg_is_pos = (pos == negs).all(-1) + pos = pos.unsqueeze(0) + targets = torch.cat([pos, negs], dim=0) + + logits = torch.cosine_similarity(x.float(), targets.float(), dim=-1).type_as(x) + logits /= self.logit_temp + if neg_is_pos.any(): + logits[1:][neg_is_pos] = float("-inf") + logits = logits.transpose(0, 1) # (num_x, num_cls+1) + return logits + + def forward_features(self, source: torch.Tensor) -> torch.Tensor: + if self.feature_grad_mult > 0: + features = self.feature_extractor(source) + if self.feature_grad_mult != 1.0: + features = GradMultiply.apply(features, self.feature_grad_mult) + else: + with torch.no_grad(): + features = self.feature_extractor(source) + return features + + def forward_targets( + self, + features: torch.Tensor, + target_list: List[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Trim features to ensure labels exist and then get aligned labels + feat_tsz = features.size(2) + targ_tsz = min([t.size(1) for t in target_list]) + if self.feat2tar_ratio * feat_tsz > targ_tsz: + feat_tsz = int(targ_tsz / self.feat2tar_ratio) + features = features[..., :feat_tsz] + target_inds = torch.arange(feat_tsz).float() * self.feat2tar_ratio + target_list = [t[:, target_inds.long()] for t in target_list] + return features, target_list + + def forward_padding_mask( + self, + features: torch.Tensor, + padding_mask: torch.Tensor, + ) -> torch.Tensor: + extra = padding_mask.size(1) % features.size(1) + if extra > 0: + padding_mask = padding_mask[:, :-extra] + padding_mask = padding_mask.view(padding_mask.size(0), features.size(1), -1) + padding_mask = padding_mask.all(-1) + return padding_mask + + def forward( + self, + source: torch.Tensor, + target_list: Optional[List[torch.Tensor]] = None, + padding_mask: Optional[torch.Tensor] = None, + mask: bool = True, + features_only: bool = False, + output_layer: Optional[int] = None, + ) -> Dict[str, torch.Tensor]: + """output layer is 1-based""" + 
features = self.forward_features(source) + if target_list is not None: + features, target_list = self.forward_targets(features, target_list) + + features_pen = features.float().pow(2).mean() + + features = features.transpose(1, 2) + features = self.layer_norm(features) + unmasked_features = features.clone() + + if padding_mask is not None: + padding_mask = self.forward_padding_mask(features, padding_mask) + + if self.post_extract_proj is not None: + features = self.post_extract_proj(features) + + features = self.dropout_input(features) + unmasked_features = self.dropout_features(unmasked_features) + + if mask: + x, mask_indices = self.apply_mask(features, padding_mask, target_list) + else: + x = features + mask_indices = None + + # feature: (B, T, D), float + # target: (B, T), long + # x: (B, T, D), float + # padding_mask: (B, T), bool + # mask_indices: (B, T), bool + x, _ = self.encoder( + x, + padding_mask=padding_mask, + layer=None if output_layer is None else output_layer - 1, + ) + + if features_only: + return {"x": x, "padding_mask": padding_mask, "features": features} + + def compute_pred(proj_x, target, label_embs): + # compute logits for the i-th label set + y = torch.index_select(label_embs, 0, target.long()) + negs = label_embs.unsqueeze(1).expand(-1, proj_x.size(0), -1) + if self.target_glu: + y = self.target_glu(y) + negs = self.target_glu(negs) + # proj_x: (S, D) + # y: (S, D) + # negs: (Neg, S, D) + return self.compute_nce(proj_x, y, negs) + + label_embs_list = self.label_embs_concat.split(self.num_classes, 0) + + if not self.skip_masked: + masked_indices = torch.logical_and(~padding_mask, mask_indices) + proj_x_m = self.final_proj(x[masked_indices]) + if self.untie_final_proj: + proj_x_m_list = proj_x_m.chunk(len(target_list), dim=-1) + else: + proj_x_m_list = [proj_x_m for _ in range(len(target_list))] + logit_m_list = [ + compute_pred(proj_x_m, t[masked_indices], label_embs_list[i]) + for i, (proj_x_m, t) in enumerate(zip(proj_x_m_list, target_list)) + ] + else: + logit_m_list = [None for _ in target_list] + + if not self.skip_nomask: + nomask_indices = torch.logical_and(~padding_mask, ~mask_indices) + proj_x_u = self.final_proj(x[nomask_indices]) + if self.untie_final_proj: + proj_x_u_list = proj_x_u.chunk(len(target_list), dim=-1) + else: + proj_x_u_list = [proj_x_u for _ in range(len(target_list))] + + logit_u_list = [ + compute_pred(proj_x_u, t[nomask_indices], label_embs_list[i]) + for i, (proj_x_u, t) in enumerate(zip(proj_x_u_list, target_list)) + ] + else: + logit_u_list = [None for _ in target_list] + + result = { + "logit_m_list": logit_m_list, + "logit_u_list": logit_u_list, + "padding_mask": padding_mask, + "features_pen": features_pen, + } + return result + + def extract_features( + self, + source: torch.Tensor, + padding_mask: Optional[torch.Tensor] = None, + mask: bool = False, + ret_conv: bool = False, + output_layer: Optional[int] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + res = self.forward( + source, + padding_mask=padding_mask, + mask=mask, + features_only=True, + output_layer=output_layer, + ) + feature = res["features"] if ret_conv else res["x"] + return feature, res["padding_mask"] + + def get_logits(self, net_output, is_masked=True): + if is_masked: + logits_list = net_output["logit_m_list"] + else: + logits_list = net_output["logit_u_list"] + logits_list = [x.float() for x in logits_list if x is not None] + return logits_list + + def get_targets(self, net_output, is_masked=True): + logits_list = self.get_logits(net_output, is_masked) + 
targets_list = [x.new_zeros(x.size(0), dtype=torch.long) for x in logits_list] + return targets_list + + def get_extra_losses(self, net_output): + extra_losses = [] + names = [] + + if "features_pen" in net_output: + extra_losses.append(net_output["features_pen"]) + names.append("features_pen") + + return extra_losses, names + + def remove_pretraining_modules(self): + self.target_glu = None + self.final_proj = None + + +def merge_with_parent(dc: dataclass, cfg: dict): + from copy import deepcopy + + assert is_dataclass(dc) + assert type(cfg) == dict + cfg = deepcopy(cfg) + + def fix_cfg(cfg): + target_keys = set(dc.__dataclass_fields__.keys()) + for k in list(cfg.keys()): + if k not in target_keys: + del cfg[k] + + fix_cfg(cfg) + assert len(cfg) > 0 + return dc(**cfg) + + +def load_hubert_model(ckpt_path, ckpt_type="converted", inference_only=True): + """ + Note: If ckpt_type is 'fairseq', torch.load() will require fairseq module + """ + assert ckpt_type in ["fairseq", "converted"] + + states = torch.load(ckpt_path, map_location="cpu", weights_only=True) + if ckpt_type == "fairseq": + # old models + if "cfg" not in states: + model_args = task_args = states["args"] + else: + model_args = states["cfg"]["model"] + task_args = states["cfg"]["task"] + model_cfg = merge_with_parent(HubertConfig, vars(model_args)) + task_cfg = merge_with_parent(HubertPretrainingConfig, vars(task_args)) + model_state = states["model"] + dictionaries = [ + dictionary.symbols for dictionary in states["task_state"]["dictionaries"] + ] + else: + model_cfg = merge_with_parent(HubertConfig, states["model_cfg"]) + task_cfg = merge_with_parent(HubertPretrainingConfig, states["task_cfg"]) + model_state = states["model_weight"] + dictionaries = states["dictionaries_symbols"] + + # Removing unnessary states + if inference_only: + dictionaries = None + for key in ["label_embs_concat", "final_proj.weight", "final_proj.bias"]: + model_state.pop(key, None) + + model = HubertModel(model_cfg, task_cfg, dictionaries=dictionaries) + model.load_state_dict(model_state, strict=False) + + _logger.info(f"Loaded HubertModel from '{ckpt_path}'") + + return model, model_cfg, task_cfg diff --git a/spiritlm/speech_tokenizer/hubert/hubert_model/wav2vec2_model.py b/spiritlm/speech_tokenizer/hubert/hubert_model/wav2vec2_model.py new file mode 100644 index 0000000000000000000000000000000000000000..32282699721a84602999fa54097452026c1c5ae8 --- /dev/null +++ b/spiritlm/speech_tokenizer/hubert/hubert_model/wav2vec2_model.py @@ -0,0 +1,3320 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the FAIR Noncommercial Research License +# found in the LICENSE file in the root directory of this source tree. 
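+ +# Transformer building blocks required by the HuBERT model above: the convolutional +# feature extractor, transformer/conformer encoder layers, multi-head attention variants +# (absolute, relative and rotary position) and the mask-sampling utilities.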
+ +# This file was copied from the s3prl package: +# - https://github.com/s3prl/s3prl/blob/main/s3prl/upstream/wav2vec2/wav2vec2_model.py +# which was adapted from fairseq to remove the dependency on the entire fairseq package + +import logging +import math +import uuid +from dataclasses import dataclass, field +from enum import Enum, EnumMeta +from typing import Callable, Dict, List, Optional, Tuple + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor + +_logger = logging.getLogger(__name__) + + +def rotate_half(x): + x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :] + return torch.cat( + (-x2, x1), dim=x1.ndim - 1 + ) # dim=-1 triggers a bug in earlier torch versions + + +def apply_rotary_pos_emb(q, k, cos, sin, offset: int = 0): + cos, sin = ( + cos[offset : q.shape[0] + offset, ...], + sin[offset : q.shape[0] + offset, ...], + ) + return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin) + + +class RotaryPositionalEmbedding(torch.nn.Module): + def __init__(self, dim, base=10000, precision=torch.half): + """Rotary positional embedding + Reference : https://blog.eleuther.ai/rotary-embeddings/ + Paper: https://arxiv.org/pdf/2104.09864.pdf + Args: + dim: Dimension of embedding + base: Base value for exponential + precision: precision to use for numerical values + """ + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer("inv_freq", inv_freq) + self.seq_len_cached = None + self.cos_cached = None + self.sin_cached = None + self.precision = precision + + def forward(self, x, seq_len=None): + """ + Args: + x: Input x with T X B X C + seq_len: Sequence length of input x + """ + if seq_len != self.seq_len_cached: + self.seq_len_cached = seq_len + t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq) + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + emb = torch.cat((freqs, freqs), dim=-1).to(x.device) + self.cos_cached = emb.cos()[:, None, None, :] + self.sin_cached = emb.sin()[:, None, None, :] + return self.cos_cached, self.sin_cached + + +class ESPNETMultiHeadedAttention(nn.Module): + """Multi-Head Attention layer. + Args: + n_head: The number of heads. + n_feat: The number of features. + dropout: Dropout rate. + """ + + def __init__(self, n_feat, n_head, dropout): + """Construct an MultiHeadedAttention object.""" + super(ESPNETMultiHeadedAttention, self).__init__() + assert n_feat % n_head == 0 + # We assume d_v always equals d_k + self.d_k = n_feat // n_head + self.h = n_head + self.linear_q = nn.Linear(n_feat, n_feat) + self.linear_k = nn.Linear(n_feat, n_feat) + self.linear_v = nn.Linear(n_feat, n_feat) + self.linear_out = nn.Linear(n_feat, n_feat) + self.attn = None + self.dropout = nn.Dropout(p=dropout) + + def forward_qkv(self, query, key, value, **kwargs): + """Transform query, key and value. 
+ Args: + query: Query tensor B X T1 X C + key: Key tensor B X T2 X C + value: Value tensor B X T2 X C + Returns: + torch.Tensor: Transformed query tensor B X n_head X T1 X d_k + torch.Tensor: Transformed key tensor B X n_head X T2 X d_k + torch.Tensor: Transformed value tensor B X n_head X T2 X d_k + """ + n_batch = query.size(0) + q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k) + k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k) + v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k) + q = q.transpose(1, 2) # (batch, head, time1, d_k) + k = k.transpose(1, 2) # (batch, head, time2, d_k) + v = v.transpose(1, 2) # (batch, head, time2, d_k) + return q, k, v + + def forward_attention(self, value, scores, mask): + """Compute attention context vector. + Args: + value: Transformed value B X n_head X T2 X d_k. + scores: Attention score B X n_head X T1 X T2 + mask: Mask T2 X B + Returns: + torch.Tensor: Transformed value B X T1 X d_model + weighted by the attention score B X T1 X T2 + """ + n_batch = value.size(0) + if mask is not None: + scores = scores.masked_fill( + mask.unsqueeze(1).unsqueeze(2).to(bool), + float("-inf"), # (batch, head, time1, time2) + ) + self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2) + + else: + self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2) + p_attn = self.dropout(self.attn) + x = torch.matmul(p_attn, value) # (batch, head, time1, d_k) + x = ( + x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k) + ) # (batch, time1, d_model) + + return self.linear_out(x) # (batch, time1, d_model) + + def forward(self, query, key, value, key_padding_mask=None, **kwargs): + """Compute scaled dot product attention. + Args: + query (torch.Tensor): Query tensor T X B X C + key (torch.Tensor): Key tensor T X B X C + value (torch.Tensor): Value tensor T X B X C + mask (torch.Tensor): Mask tensor T X B + Returns: + torch.Tensor: Output tensor T X B X D. + """ + query = query.transpose(0, 1) + key = key.transpose(0, 1) + value = value.transpose(0, 1) + + q, k, v = self.forward_qkv(query, key, value) + scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k) + scores = self.forward_attention(v, scores, key_padding_mask) + scores = scores.transpose(0, 1) + return scores, None + + +class RelPositionMultiHeadedAttention(ESPNETMultiHeadedAttention): + """Multi-Head Attention layer with relative position encoding. + Paper: https://arxiv.org/abs/1901.02860 + Args: + n_head: The number of heads. + n_feat: The number of features. + dropout: Dropout rate. + zero_triu: Whether to zero the upper triangular part of attention matrix. + """ + + def __init__(self, n_feat, n_head, dropout, zero_triu=False): + """Construct an RelPositionMultiHeadedAttention object.""" + super().__init__(n_feat, n_head, dropout) + self.zero_triu = zero_triu + # linear transformation for positional encoding + self.linear_pos = nn.Linear(n_feat, n_feat, bias=False) + # these two learnable bias are used in matrix c and matrix d + # as described in https://arxiv.org/abs/1901.02860 Section 3.3 + self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k)) + self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k)) + torch.nn.init.xavier_uniform_(self.pos_bias_u) + torch.nn.init.xavier_uniform_(self.pos_bias_v) + + def rel_shift(self, x): + """Compute relative positional encoding. + Args: + x: Input tensor B X n_head X T X 2T-1 + Returns: + torch.Tensor: Output tensor. 
+ """ + zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype) + x_padded = torch.cat([zero_pad, x], dim=-1) + + x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2)) + x = x_padded[:, :, 1:].view_as(x)[ + :, :, :, : x.size(-1) // 2 + 1 + ] # only keep the positions from 0 to time2 + + if self.zero_triu: + ones = torch.ones((x.size(2), x.size(3)), device=x.device) + x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :] + + return x + + def forward(self, query, key, value, pos_emb, key_padding_mask=None, **kwargs): + """Compute scaled dot product attention. + Args: + query: Query tensor T X B X C + key: Key tensor T X B X C + value: Value tensor T X B X C + pos_emb: Positional embedding tensor B X 2T-1 X C + key_padding_mask: Mask tensor T X B + Returns: + torch.Tensor: Output tensor T X B X C. + """ + query = query.transpose(0, 1) + key = key.transpose(0, 1) + value = value.transpose(0, 1) + pos_emb = pos_emb.transpose(0, 1) + q, k, v = self.forward_qkv(query, key, value) + q = q.transpose(1, 2) # (batch, time1, head, d_k) + n_batch_pos = pos_emb.size(0) + p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k) + p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k) + + # (batch, head, time1, d_k) + q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2) + # (batch, head, time1, d_k) + q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2) + + # compute attention score + # first compute matrix a and matrix c + # as described in https://arxiv.org/abs/1901.02860 Section 3.3 + # (batch, head, time1, time2) + matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1)) + + # compute matrix b and matrix d + # (batch, head, time1, 2*time1-1) + matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1)) + matrix_bd = self.rel_shift(matrix_bd) + + scores = (matrix_ac + matrix_bd) / math.sqrt( + self.d_k + ) # (batch, head, time1, time2) + + scores = self.forward_attention(v, scores, key_padding_mask) + scores = scores.transpose(0, 1) + return scores, None + + +class RotaryPositionMultiHeadedAttention(ESPNETMultiHeadedAttention): + def __init__( + self, + n_feat, + n_head, + dropout, + precision, + rotary_emd_base=10000, + ): + """Construct an RotaryPositionMultiHeadedAttention object.""" + super().__init__(n_feat, n_head, dropout) + precision = torch.float + self.rotary_ndims = self.d_k # also try self.d_k//2 + if precision == "fp16": + precision = torch.half + + self.rotary_emb = RotaryPositionalEmbedding( + self.rotary_ndims, base=rotary_emd_base, precision=precision + ) + + def forward(self, query, key, value, key_padding_mask=None, **kwargs): + """Compute rotary position attention. + Args: + query: Query tensor T X B X C + key: Key tensor T X B X C + value: Value tensor T X B X C + key_padding_mask: Mask tensor T X B + Returns: + torch.Tensor: Output tensor T X B X D. 
+ Notes: + Assumes self attn + """ + + T, B, C = value.size() + query = query.view(T, B, self.h, self.d_k) + key = key.view(T, B, self.h, self.d_k) + value = value.view(T, B, self.h, self.d_k) + cos, sin = self.rotary_emb(value, seq_len=T) + query, key = apply_rotary_pos_emb( + query, key, cos, sin, offset=0 + ) # offset is based on layer_past + + query = query.view(T, B, self.h * self.d_k) + key = key.view(T, B, self.h * self.d_k) + value = value.view(T, B, self.h * self.d_k) + + # TBD to BTD + query = query.transpose(0, 1) + key = key.transpose(0, 1) + value = value.transpose(0, 1) + + q, k, v = self.forward_qkv(query, key, value) + scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k) + scores = self.forward_attention(v, scores, key_padding_mask) + scores = scores.transpose(0, 1) + return scores, None + + +class ConvolutionModule(torch.nn.Module): + """Convolution block used in the conformer block""" + + def __init__( + self, + embed_dim, + channels, + depthwise_kernel_size, + dropout, + activation_fn="swish", + bias=False, + export=False, + ): + """ + Args: + embed_dim: Embedding dimension + channels: Number of channels in depthwise conv layers + depthwise_kernel_size: Depthwise conv layer kernel size + dropout: dropout value + activation_fn: Activation function to use after depthwise convolution kernel + bias: If bias should be added to conv layers + export: If layernorm should be exported to jit + """ + super(ConvolutionModule, self).__init__() + assert ( + depthwise_kernel_size - 1 + ) % 2 == 0, "kernel_size should be a odd number for 'SAME' padding" + self.layer_norm = LayerNorm(embed_dim, export=export) + self.pointwise_conv1 = torch.nn.Conv1d( + embed_dim, + 2 * channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + ) + self.glu = torch.nn.GLU(dim=1) + self.depthwise_conv = torch.nn.Conv1d( + channels, + channels, + depthwise_kernel_size, + stride=1, + padding=(depthwise_kernel_size - 1) // 2, + groups=channels, + bias=bias, + ) + self.batch_norm = torch.nn.BatchNorm1d(channels) + self.activation = get_activation_fn(activation_fn)(channels) + self.pointwise_conv2 = torch.nn.Conv1d( + channels, + embed_dim, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + ) + self.dropout = torch.nn.Dropout(dropout) + + def forward(self, x): + """ + Args: + x: Input of shape B X T X C + Returns: + Tensor of shape B X T X C + """ + x = self.layer_norm(x) + # exchange the temporal dimension and the feature dimension + x = x.transpose(1, 2) + + # GLU mechanism + x = self.pointwise_conv1(x) # (batch, 2*channel, dim) + x = self.glu(x) # (batch, channel, dim) + + # 1D Depthwise Conv + x = self.depthwise_conv(x) + x = self.batch_norm(x) + x = self.activation(x) + + x = self.pointwise_conv2(x) + x = self.dropout(x) + return x.transpose(1, 2) + + +class FeedForwardModule(torch.nn.Module): + """Positionwise feed forward layer used in conformer""" + + def __init__( + self, + input_feat, + hidden_units, + dropout1, + dropout2, + activation_fn="swish", + bias=True, + ): + """ + Args: + input_feat: Input feature dimension + hidden_units: Hidden unit dimension + dropout1: dropout value for layer1 + dropout2: dropout value for layer2 + activation_fn: Name of activation function + bias: If linear layers should have bias + """ + + super(FeedForwardModule, self).__init__() + self.layer_norm = LayerNorm(input_feat) + self.w_1 = torch.nn.Linear(input_feat, hidden_units, bias=bias) + self.w_2 = torch.nn.Linear(hidden_units, input_feat, bias=bias) + self.dropout1 = 
torch.nn.Dropout(dropout1) + self.dropout2 = torch.nn.Dropout(dropout2) + self.activation = get_activation_fn(activation_fn)(hidden_units) + + def forward(self, x): + """ + Args: + x: Input Tensor of shape T X B X C + Returns: + Tensor of shape T X B X C + """ + x = self.layer_norm(x) + x = self.w_1(x) + x = self.activation(x) + x = self.dropout1(x) + x = self.w_2(x) + return self.dropout2(x) + + +class ConformerEncoderLayer(torch.nn.Module): + """Conformer block based on https://arxiv.org/abs/2005.08100. We currently don't support relative positional encoding in MHA""" + + def __init__( + self, + embed_dim, + ffn_embed_dim, + attention_heads, + dropout, + use_fp16, + depthwise_conv_kernel_size=31, + activation_fn="swish", + attn_type=None, + pos_enc_type="abs", + ): + """ + Args: + embed_dim: Input embedding dimension + ffn_embed_dim: FFN layer dimension + attention_heads: Number of attention heads in MHA + dropout: dropout value + depthwise_conv_kernel_size: Size of kernel in depthwise conv layer in convolution module + activation_fn: Activation function name to use in convulation block and feed forward block + attn_type: MHA implementation from ESPNET vs fairseq + pos_enc_type: Positional encoding type - abs, rope, rel_pos + """ + self.pos_enc_type = pos_enc_type + super(ConformerEncoderLayer, self).__init__() + + self.ffn1 = FeedForwardModule( + embed_dim, + ffn_embed_dim, + dropout, + dropout, + ) + + self.self_attn_layer_norm = LayerNorm(embed_dim, export=False) + self.self_attn_dropout = torch.nn.Dropout(dropout) + if attn_type == "espnet": + if self.pos_enc_type == "rel_pos": + self.self_attn = RelPositionMultiHeadedAttention( + embed_dim, + attention_heads, + dropout=dropout, + ) + elif self.pos_enc_type == "rope": + self.self_attn = RotaryPositionMultiHeadedAttention( + embed_dim, attention_heads, dropout=dropout, precision=use_fp16 + ) + elif self.pos_enc_type == "abs": + self.self_attn = ESPNETMultiHeadedAttention( + embed_dim, + attention_heads, + dropout=dropout, + ) + else: + raise Exception(f"Unsupported attention type {self.pos_enc_type}") + else: + # Default to fairseq MHA + self.self_attn = MultiheadAttention( + embed_dim, + attention_heads, + dropout=dropout, + ) + + self.conv_module = ConvolutionModule( + embed_dim=embed_dim, + channels=embed_dim, + depthwise_kernel_size=depthwise_conv_kernel_size, + dropout=dropout, + activation_fn=activation_fn, + ) + + self.ffn2 = FeedForwardModule( + embed_dim, + ffn_embed_dim, + dropout, + dropout, + activation_fn=activation_fn, + ) + self.final_layer_norm = LayerNorm(embed_dim, export=False) + + def forward( + self, + x, + encoder_padding_mask: Optional[torch.Tensor], + position_emb: Optional[torch.Tensor] = None, + ): + """ + Args: + x: Tensor of shape T X B X C + encoder_padding_mask: Optional mask tensor + positions: + Returns: + Tensor of shape T X B X C + """ + residual = x + x = self.ffn1(x) + x = x * 0.5 + residual + residual = x + x = self.self_attn_layer_norm(x) + if self.pos_enc_type == "rel_pos": + x, attn = self.self_attn( + query=x, + key=x, + value=x, + key_padding_mask=encoder_padding_mask, + pos_emb=position_emb, + need_weights=False, + ) + else: + x, attn = self.self_attn( + query=x, + key=x, + value=x, + key_padding_mask=encoder_padding_mask, + need_weights=False, + ) + x = self.self_attn_dropout(x) + x = x + residual + + residual = x + # TBC to BTC + x = x.transpose(0, 1) + x = self.conv_module(x) + # BTC to TBC + x = x.transpose(0, 1) + x = residual + x + + residual = x + x = self.ffn2(x) + + layer_result = x 
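+ # Macaron-style half-step residual as in the Conformer paper: both
+ # feed-forward modules (ffn1 above and ffn2 here) are scaled by 0.5
+ # before being added back to the residual stream.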
+ + x = x * 0.5 + residual + + x = self.final_layer_norm(x) + return x, (attn, layer_result) + + +class ConformerWav2Vec2EncoderLayer(ConformerEncoderLayer): + """Encoder layer for Wav2vec2 encoder""" + + def forward( + self, + x: torch.Tensor, + self_attn_mask: torch.Tensor = None, + self_attn_padding_mask: torch.Tensor = None, + need_weights: bool = False, + att_args=None, + position_emb=None, + ): + return super().forward(x, self_attn_padding_mask, position_emb) + + +class FairseqIncrementalState(object): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.init_incremental_state() + + def init_incremental_state(self): + self._incremental_state_id = str(uuid.uuid4()) + + def _get_full_incremental_state_key(self, key: str) -> str: + return "{}.{}".format(self._incremental_state_id, key) + + def get_incremental_state( + self, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]], + key: str, + ) -> Optional[Dict[str, Optional[Tensor]]]: + """Helper for getting incremental state for an nn.Module.""" + full_key = self._get_full_incremental_state_key(key) + if incremental_state is None or full_key not in incremental_state: + return None + return incremental_state[full_key] + + def set_incremental_state( + self, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]], + key: str, + value: Dict[str, Optional[Tensor]], + ) -> Optional[Dict[str, Dict[str, Optional[Tensor]]]]: + """Helper for setting incremental state for an nn.Module.""" + if incremental_state is not None: + full_key = self._get_full_incremental_state_key(key) + incremental_state[full_key] = value + return incremental_state + + +def with_incremental_state(cls): + cls.__bases__ = (FairseqIncrementalState,) + tuple( + b for b in cls.__bases__ if b != FairseqIncrementalState + ) + return cls + + +class FairseqDropout(nn.Module): + def __init__(self, p, module_name=None): + super().__init__() + self.p = p + self.module_name = module_name + self.apply_during_inference = False + + def forward(self, x, inplace: bool = False): + if self.p > 0 and (self.training or self.apply_during_inference): + return F.dropout(x, p=self.p, training=True, inplace=inplace) + else: + return x + + def make_generation_fast_( + self, + name: str, + retain_dropout: bool = False, + retain_dropout_modules: Optional[List[str]] = None, + **kwargs, + ): + if retain_dropout: + if retain_dropout_modules is not None and self.module_name is None: + _logger.warning( + "Cannot enable dropout during inference for module {} " + "because module_name was not set".format(name) + ) + elif ( + retain_dropout_modules is None # if None, apply to all modules + or self.module_name in retain_dropout_modules + ): + _logger.info( + "Enabling dropout during inference for module: {}".format(name) + ) + self.apply_during_inference = True + else: + _logger.info("Disabling dropout for module: {}".format(name)) + + +def quant_noise(module, p, block_size): + """ + Wraps modules and applies quantization noise to the weights for + subsequent quantization with Iterative Product Quantization as + described in "Training with Quantization Noise for Extreme Model Compression" + + Args: + - module: nn.Module + - p: amount of Quantization Noise + - block_size: size of the blocks for subsequent quantization with iPQ + + Remarks: + - Module weights must have the right sizes wrt the block size + - Only Linear, Embedding and Conv2d modules are supported for the moment + - For more detail on how to quantize by blocks with convolutional weights, 
+ see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks" + - We implement the simplest form of noise here as stated in the paper + which consists in randomly dropping blocks + """ + + # if no quantization noise, don't register hook + if p <= 0: + return module + + # supported modules + assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d)) + + # test whether module.weight has the right sizes wrt block_size + is_conv = module.weight.ndim == 4 + + # 2D matrix + if not is_conv: + assert ( + module.weight.size(1) % block_size == 0 + ), "Input features must be a multiple of block sizes" + + # 4D matrix + else: + # 1x1 convolutions + if module.kernel_size == (1, 1): + assert ( + module.in_channels % block_size == 0 + ), "Input channels must be a multiple of block sizes" + # regular convolutions + else: + k = module.kernel_size[0] * module.kernel_size[1] + assert k % block_size == 0, "Kernel size must be a multiple of block size" + + def _forward_pre_hook(mod, input): + # no noise for evaluation + if mod.training: + if not is_conv: + # gather weight and sizes + weight = mod.weight + in_features = weight.size(1) + out_features = weight.size(0) + + # split weight matrix into blocks and randomly drop selected blocks + mask = torch.zeros( + in_features // block_size * out_features, device=weight.device + ) + mask.bernoulli_(p) + mask = mask.repeat_interleave(block_size, -1).view(-1, in_features) + + else: + # gather weight and sizes + weight = mod.weight + in_channels = mod.in_channels + out_channels = mod.out_channels + + # split weight matrix into blocks and randomly drop selected blocks + if mod.kernel_size == (1, 1): + mask = torch.zeros( + int(in_channels // block_size * out_channels), + device=weight.device, + ) + mask.bernoulli_(p) + mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels) + else: + mask = torch.zeros( + weight.size(0), weight.size(1), device=weight.device + ) + mask.bernoulli_(p) + mask = ( + mask.unsqueeze(2) + .unsqueeze(3) + .repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1]) + ) + + # scale weights and apply mask + mask = mask.to( + torch.bool + ) # x.bool() is not currently supported in TorchScript + s = 1 / (1 - p) + mod.weight.data = s * weight.masked_fill(mask, 0) + + module.register_forward_pre_hook(_forward_pre_hook) + return module + + +@with_incremental_state +class MultiheadAttention(nn.Module): + """Multi-headed attention. + + See "Attention Is All You Need" for more details. + """ + + def __init__( + self, + embed_dim, + num_heads, + kdim=None, + vdim=None, + dropout=0.0, + bias=True, + add_bias_kv=False, + add_zero_attn=False, + self_attention=False, + encoder_decoder_attention=False, + q_noise=0.0, + qn_block_size=8, + # TODO: pass in config rather than string. 
+ # config defined in xformers.components.attention.AttentionConfig + xformers_att_config: Optional[str] = None, + xformers_blocksparse_layout: Optional[ + torch.Tensor + ] = None, # This should be part of the config + xformers_blocksparse_blocksize: Optional[ + int + ] = 16, # This should be part of the config + ): + super().__init__() + + def eval_str_dict(x, type=dict): + if x is None: + return None + if isinstance(x, str): + x = eval(x) + return x + + xformers_att_config = eval_str_dict(xformers_att_config) + self.use_xformers = xformers_att_config is not None + assert not self.use_xformers, "Do not use xformers in S3PRL" + + self.embed_dim = embed_dim + self.kdim = kdim if kdim is not None else embed_dim + self.vdim = vdim if vdim is not None else embed_dim + self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim + + self.num_heads = num_heads + self.dropout_module = FairseqDropout( + dropout, module_name=self.__class__.__name__ + ) + + self.head_dim = embed_dim // num_heads + assert ( + self.head_dim * num_heads == self.embed_dim + ), "embed_dim must be divisible by num_heads" + self.scaling = self.head_dim**-0.5 + + self.self_attention = self_attention + self.encoder_decoder_attention = encoder_decoder_attention + + assert not self.self_attention or self.qkv_same_dim, ( + "Self-attention requires query, key and " "value to be of the same size" + ) + + self.k_proj = quant_noise( + nn.Linear(self.kdim, embed_dim, bias=bias), q_noise, qn_block_size + ) + self.v_proj = quant_noise( + nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size + ) + self.q_proj = quant_noise( + nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size + ) + + self.out_proj = quant_noise( + nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size + ) + + if add_bias_kv: + self.bias_k = nn.Parameter(torch.Tensor(1, 1, embed_dim)) + self.bias_v = nn.Parameter(torch.Tensor(1, 1, embed_dim)) + else: + self.bias_k = self.bias_v = None + + self.add_zero_attn = add_zero_attn + self.beam_size = 1 + self.reset_parameters() + + self.onnx_trace = False + self.skip_embed_dim_check = False + + def prepare_for_onnx_export_(self): + self.onnx_trace = True + + def reset_parameters(self): + if self.qkv_same_dim: + # Empirically observed the convergence to be much better with + # the scaled initialization + nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2)) + nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2)) + nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2)) + else: + nn.init.xavier_uniform_(self.k_proj.weight) + nn.init.xavier_uniform_(self.v_proj.weight) + nn.init.xavier_uniform_(self.q_proj.weight) + + nn.init.xavier_uniform_(self.out_proj.weight) + if self.out_proj.bias is not None: + nn.init.constant_(self.out_proj.bias, 0.0) + if self.bias_k is not None: + nn.init.xavier_normal_(self.bias_k) + if self.bias_v is not None: + nn.init.xavier_normal_(self.bias_v) + + def _get_reserve_head_index(self, num_heads_to_keep: int): + k_proj_heads_norm = [] + q_proj_heads_norm = [] + v_proj_heads_norm = [] + + for i in range(self.num_heads): + start_idx = i * self.head_dim + end_idx = (i + 1) * self.head_dim + k_proj_heads_norm.append( + torch.sum(torch.abs(self.k_proj.weight[start_idx:end_idx,])).tolist() + + torch.sum(torch.abs(self.k_proj.bias[start_idx:end_idx])).tolist() + ) + q_proj_heads_norm.append( + torch.sum(torch.abs(self.q_proj.weight[start_idx:end_idx,])).tolist() + + 
torch.sum(torch.abs(self.q_proj.bias[start_idx:end_idx])).tolist() + ) + v_proj_heads_norm.append( + torch.sum(torch.abs(self.v_proj.weight[start_idx:end_idx,])).tolist() + + torch.sum(torch.abs(self.v_proj.bias[start_idx:end_idx])).tolist() + ) + + heads_norm = [] + for i in range(self.num_heads): + heads_norm.append( + k_proj_heads_norm[i] + q_proj_heads_norm[i] + v_proj_heads_norm[i] + ) + + sorted_head_index = sorted( + range(self.num_heads), key=lambda k: heads_norm[k], reverse=True + ) + reserve_head_index = [] + for i in range(num_heads_to_keep): + start = sorted_head_index[i] * self.head_dim + end = (sorted_head_index[i] + 1) * self.head_dim + reserve_head_index.append((start, end)) + return reserve_head_index + + def _adaptive_prune_heads(self, reserve_head_index: List[Tuple[int, int]]): + new_q_weight = [] + new_q_bias = [] + new_k_weight = [] + new_k_bias = [] + new_v_weight = [] + new_v_bias = [] + new_out_proj_weight = [] + + for ele in reserve_head_index: + start_idx, end_idx = ele + new_q_weight.append(self.q_proj.weight[start_idx:end_idx,]) + new_q_bias.append(self.q_proj.bias[start_idx:end_idx]) + + new_k_weight.append(self.k_proj.weight[start_idx:end_idx,]) + + new_k_bias.append(self.k_proj.bias[start_idx:end_idx]) + + new_v_weight.append(self.v_proj.weight[start_idx:end_idx,]) + new_v_bias.append(self.v_proj.bias[start_idx:end_idx]) + + new_out_proj_weight.append(self.out_proj.weight[:, start_idx:end_idx]) + + new_q_weight = torch.cat(new_q_weight).detach() + new_k_weight = torch.cat(new_k_weight).detach() + new_v_weight = torch.cat(new_v_weight).detach() + new_out_proj_weight = torch.cat(new_out_proj_weight, dim=-1).detach() + new_q_weight.requires_grad = True + new_k_weight.requires_grad = True + new_v_weight.requires_grad = True + new_out_proj_weight.requires_grad = True + + new_q_bias = torch.cat(new_q_bias).detach() + new_q_bias.requires_grad = True + + new_k_bias = torch.cat(new_k_bias).detach() + new_k_bias.requires_grad = True + + new_v_bias = torch.cat(new_v_bias).detach() + new_v_bias.requires_grad = True + + self.q_proj.weight = torch.nn.Parameter(new_q_weight) + self.q_proj.bias = torch.nn.Parameter(new_q_bias) + + self.k_proj.weight = torch.nn.Parameter(new_k_weight) + self.k_proj.bias = torch.nn.Parameter(new_k_bias) + + self.v_proj.weight = torch.nn.Parameter(new_v_weight) + self.v_proj.bias = torch.nn.Parameter(new_v_bias) + + self.out_proj.weight = torch.nn.Parameter(new_out_proj_weight) + + self.num_heads = len(reserve_head_index) + self.embed_dim = self.head_dim * self.num_heads + self.q_proj.out_features = self.embed_dim + self.k_proj.out_features = self.embed_dim + self.v_proj.out_features = self.embed_dim + + def _set_skip_embed_dim_check(self): + self.skip_embed_dim_check = True + + def _pad_masks( + self, + key_padding_mask: Optional[Tensor], + attn_mask: Optional[Tensor], + ) -> Tuple[Optional[Tensor], Optional[Tensor]]: + if attn_mask is not None: + shape = attn_mask.size()[:-1] + torch.Size([1]) + attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(shape)], dim=-1) + if key_padding_mask is not None: + shape = key_padding_mask.size()[:-1] + torch.Size([1]) + key_padding_mask = torch.cat( + [ + key_padding_mask, + key_padding_mask.new_zeros(shape), + ], + dim=-1, + ) + return key_padding_mask, attn_mask + + def _add_bias( + self, + k: Tensor, + v: Tensor, + key_padding_mask: Optional[Tensor], + attn_mask: Optional[Tensor], + bsz: int, + ) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]: + assert self.bias_k is not None + assert 
self.bias_v is not None + k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)]) + v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)]) + key_padding_mask, attn_mask = self._pad_masks( + key_padding_mask=key_padding_mask, attn_mask=attn_mask + ) + return k, v, key_padding_mask, attn_mask + + def _append_zero_attn( + self, + k: Tensor, + v: Tensor, + key_padding_mask: Optional[Tensor], + attn_mask: Optional[Tensor], + ) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]: + zero_attn_shape = k.size()[:-2] + torch.Size([1]) + k.size()[-1:] + k = torch.cat( + [k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=-2 + ) + v = torch.cat( + [v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=-2 + ) + key_padding_mask, attn_mask = self._pad_masks( + key_padding_mask=key_padding_mask, attn_mask=attn_mask + ) + return k, v, key_padding_mask, attn_mask + + def forward( + self, + query, + key: Optional[Tensor], + value: Optional[Tensor], + key_padding_mask: Optional[Tensor] = None, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + need_weights: bool = True, + static_kv: bool = False, + attn_mask: Optional[Tensor] = None, + before_softmax: bool = False, + need_head_weights: bool = False, + ) -> Tuple[Tensor, Optional[Tensor]]: + """Input shape: Time x Batch x Channel + + Args: + key_padding_mask (ByteTensor, optional): mask to exclude + keys that are pads, of shape `(batch, src_len)`, where + padding elements are indicated by 1s. + need_weights (bool, optional): return the attention weights, + averaged over heads (default: False). + attn_mask (ByteTensor, optional): typically used to + implement causal attention, where the mask prevents the + attention from looking forward in time (default: None). + before_softmax (bool, optional): return the raw attention + weights and values before the attention softmax. + need_head_weights (bool, optional): return the attention + weights for each head. Implies *need_weights*. Default: + return the average attention weights over all heads. + """ + if need_head_weights: + need_weights = True + + is_tpu = query.device.type == "xla" + + tgt_len, bsz, embed_dim = query.size() + src_len = tgt_len + if not self.skip_embed_dim_check: + assert ( + embed_dim == self.embed_dim + ), f"query dim {embed_dim} != {self.embed_dim}" + assert list(query.size()) == [tgt_len, bsz, embed_dim] + if key is not None: + src_len, key_bsz, _ = key.size() + if not torch.jit.is_scripting(): + assert value is not None + assert src_len, key_bsz == value.shape[:2] + + if ( + not self.onnx_trace + and not is_tpu # don't use PyTorch version on TPUs + and incremental_state is None + and not static_kv + # A workaround for quantization to work. Otherwise JIT compilation + # treats bias in linear module as method. + and not torch.jit.is_scripting() + # The Multihead attention implemented in pytorch forces strong dimension check + # for input embedding dimention and K,Q,V projection dimension. 
+ # Since pruning will break the dimension check and it is not easy to modify the pytorch API, + # it is preferred to bypass the pytorch MHA when we need to skip embed_dim_check + and not self.skip_embed_dim_check + ): + assert key is not None and value is not None + + if self.use_xformers: + return self._xformers_attn_forward( + query, key, value, key_padding_mask, need_weights, attn_mask + ) + + else: + return F.multi_head_attention_forward( + query, + key, + value, + self.embed_dim, + self.num_heads, + torch.empty([0]), + torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)), + self.bias_k, + self.bias_v, + self.add_zero_attn, + self.dropout_module.p, + self.out_proj.weight, + self.out_proj.bias, + self.training or self.dropout_module.apply_during_inference, + key_padding_mask, + need_weights, + attn_mask, + use_separate_proj_weight=True, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + ) + + if incremental_state is not None: + saved_state = self._get_input_buffer(incremental_state) + if saved_state is not None and "prev_key" in saved_state: + # previous time steps are cached - no need to recompute + # key and value if they are static + if static_kv: + assert self.encoder_decoder_attention and not self.self_attention + key = value = None + else: + saved_state = None + + if self.self_attention: + q = self.q_proj(query) + k = self.k_proj(query) + v = self.v_proj(query) + elif self.encoder_decoder_attention: + # encoder-decoder attention + q = self.q_proj(query) + if key is None: + assert value is None + k = v = None + else: + if self.beam_size > 1 and bsz == key.size(1): + # key is [T, bsz*beam_size, C], reduce to [T, bsz, C] + key = key.view(key.size(0), -1, self.beam_size, key.size(2))[ + :, :, 0, : + ] + if key_padding_mask is not None: + key_padding_mask = key_padding_mask.view( + -1, self.beam_size, key_padding_mask.size(1) + )[:, 0, :] + k = self.k_proj(key) + v = self.v_proj(key) + + else: + assert key is not None and value is not None + q = self.q_proj(query) + k = self.k_proj(key) + v = self.v_proj(value) + q *= self.scaling + + if self.bias_k is not None: + assert self.bias_v is not None + k, v, attn_mask, key_padding_mask = self._add_bias( + k, v, attn_mask, key_padding_mask, bsz + ) + + q = ( + q.contiguous() + .view(tgt_len, bsz * self.num_heads, self.head_dim) + .transpose(0, 1) + ) + kv_bsz = bsz # need default value for scripting + if k is not None: + kv_bsz = k.size(1) + k = ( + k.contiguous() + .view(-1, kv_bsz * self.num_heads, self.head_dim) + .transpose(0, 1) + ) + if v is not None: + v = ( + v.contiguous() + .view(-1, kv_bsz * self.num_heads, self.head_dim) + .transpose(0, 1) + ) + + if saved_state is not None: + # saved states are stored with shape (bsz, num_heads, seq_len, head_dim) + if "prev_key" in saved_state: + _prev_key = saved_state["prev_key"] + assert _prev_key is not None + kv_bsz = _prev_key.size(0) + prev_key = _prev_key.view(kv_bsz * self.num_heads, -1, self.head_dim) + if static_kv: + k = prev_key + else: + assert k is not None + k = torch.cat([prev_key, k], dim=1) + src_len = k.size(1) + if "prev_value" in saved_state: + _prev_value = saved_state["prev_value"] + assert _prev_value is not None + assert kv_bsz == _prev_value.size(0) + prev_value = _prev_value.view( + kv_bsz * self.num_heads, -1, self.head_dim + ) + if static_kv: + v = prev_value + else: + assert v is not None + v = torch.cat([prev_value, v], dim=1) + prev_key_padding_mask: Optional[Tensor] = None + if 
"prev_key_padding_mask" in saved_state: + prev_key_padding_mask = saved_state["prev_key_padding_mask"] + assert k is not None and v is not None + key_padding_mask = MultiheadAttention._append_prev_key_padding_mask( + key_padding_mask=key_padding_mask, + prev_key_padding_mask=prev_key_padding_mask, + batch_size=kv_bsz, + src_len=k.size(1), + static_kv=static_kv, + ) + + saved_state["prev_key"] = k.view(kv_bsz, self.num_heads, -1, self.head_dim) + saved_state["prev_value"] = v.view( + kv_bsz, self.num_heads, -1, self.head_dim + ) + saved_state["prev_key_padding_mask"] = key_padding_mask + # In this branch incremental_state is never None + assert incremental_state is not None + incremental_state = self._set_input_buffer(incremental_state, saved_state) + assert k is not None + assert k.size(1) == src_len + + # This is part of a workaround to get around fork/join parallelism + # not supporting Optional types. + if key_padding_mask is not None and key_padding_mask.dim() == 0: + key_padding_mask = None + + if key_padding_mask is not None: + assert key_padding_mask.size(0) == kv_bsz + assert key_padding_mask.size(1) == src_len + + if self.add_zero_attn: + assert v is not None + src_len += 1 + k, v, key_padding_mask, attn_mask = self._append_zero_attn( + k=k, v=v, key_padding_mask=key_padding_mask, attn_mask=attn_mask + ) + + if self.encoder_decoder_attention and bsz != kv_bsz: + attn_weights = torch.einsum( + "bxhtd,bhsd->bxhts", + q.view((kv_bsz, -1, self.num_heads) + q.size()[1:]), + k.view((kv_bsz, self.num_heads) + k.size()[1:]), + ) + attn_weights = attn_weights.reshape((-1,) + attn_weights.size()[-2:]) + else: + attn_weights = torch.bmm(q, k.transpose(1, 2)) + attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz) + + assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len] + + if attn_mask is not None: + attn_mask = attn_mask.unsqueeze(0) + if self.onnx_trace: + attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1) + attn_weights += attn_mask + + if key_padding_mask is not None: + # don't attend to padding symbols + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + if not is_tpu: + attn_weights = attn_weights.view( + kv_bsz, -1, self.num_heads, tgt_len, src_len + ) + attn_weights = attn_weights.masked_fill( + key_padding_mask.unsqueeze(1) + .unsqueeze(2) + .unsqueeze(3) + .to(torch.bool), + float("-inf"), + ) + else: + attn_weights = attn_weights.transpose(0, 2) + attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf")) + attn_weights = attn_weights.transpose(0, 2) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if before_softmax: + return attn_weights, v + + def softmax_supporting_onnx_trace(x, dim: int, onnx_trace: bool = False): + if onnx_trace: + return F.softmax(x.float(), dim=dim) + else: + return F.softmax(x, dim=dim, dtype=torch.float32) + + attn_weights_float = softmax_supporting_onnx_trace( + attn_weights, dim=-1, onnx_trace=self.onnx_trace + ) + attn_weights = attn_weights_float.type_as(attn_weights) + attn_probs = self.dropout_module(attn_weights) + + assert v is not None + if self.encoder_decoder_attention and bsz != kv_bsz: + attn = torch.einsum( + "bxhts,bhsd->bxhtd", + attn_probs.view( + ( + kv_bsz, + -1, + self.num_heads, + ) + + attn_probs.size()[1:] + ), + v.view( + ( + kv_bsz, + self.num_heads, + ) + + v.size()[1:] + ), + ) + attn = attn.reshape((-1,) + attn.size()[-2:]) + else: + attn = torch.bmm(attn_probs, v) + assert list(attn.size()) == [bsz * 
self.num_heads, tgt_len, self.head_dim] + if self.onnx_trace and attn.size(1) == 1: + # when ONNX tracing a single decoder step (sequence length == 1) + # the transpose is a no-op copy before view, thus unnecessary + attn = attn.contiguous().view(tgt_len, bsz, self.embed_dim) + else: + attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, self.embed_dim) + attn = self.out_proj(attn) + attn_weights: Optional[Tensor] = None + if need_weights: + attn_weights = attn_weights_float.view( + bsz, self.num_heads, tgt_len, src_len + ).transpose(1, 0) + if not need_head_weights: + # average attention weights over heads + attn_weights = attn_weights.mean(dim=0) + + return attn, attn_weights + + @staticmethod + def _append_prev_key_padding_mask( + key_padding_mask: Optional[Tensor], + prev_key_padding_mask: Optional[Tensor], + batch_size: int, + src_len: int, + static_kv: bool, + ) -> Optional[Tensor]: + # saved key padding masks have shape (bsz, seq_len) + if prev_key_padding_mask is not None and static_kv: + new_key_padding_mask = prev_key_padding_mask + elif prev_key_padding_mask is not None and key_padding_mask is not None: + new_key_padding_mask = torch.cat( + [prev_key_padding_mask.float(), key_padding_mask.float()], dim=1 + ) + # During incremental decoding, as the padding token enters and + # leaves the frame, there will be a time when prev or current + # is None + elif prev_key_padding_mask is not None: + if src_len > prev_key_padding_mask.size(1): + filler = torch.zeros( + (batch_size, src_len - prev_key_padding_mask.size(1)), + device=prev_key_padding_mask.device, + ) + new_key_padding_mask = torch.cat( + [prev_key_padding_mask.float(), filler.float()], dim=1 + ) + else: + new_key_padding_mask = prev_key_padding_mask.float() + elif key_padding_mask is not None: + if src_len > key_padding_mask.size(1): + filler = torch.zeros( + (batch_size, src_len - key_padding_mask.size(1)), + device=key_padding_mask.device, + ) + new_key_padding_mask = torch.cat( + [filler.float(), key_padding_mask.float()], dim=1 + ) + else: + new_key_padding_mask = key_padding_mask.float() + else: + new_key_padding_mask = prev_key_padding_mask + return new_key_padding_mask + + @torch.jit.export + def reorder_incremental_state( + self, + incremental_state: Dict[str, Dict[str, Optional[Tensor]]], + new_order: Tensor, + ): + """Reorder buffered internal state (for incremental generation).""" + input_buffer = self._get_input_buffer(incremental_state) + if input_buffer is not None: + for k in input_buffer.keys(): + input_buffer_k = input_buffer[k] + if input_buffer_k is not None: + if self.encoder_decoder_attention: + if input_buffer_k.size(0) * self.beam_size == new_order.size(0): + return incremental_state + elif self.beam_size > 1: + input_buffer[k] = input_buffer_k.index_select( + 0, + new_order.reshape(-1, self.beam_size)[:, 0] + // self.beam_size, + ) + else: + input_buffer[k] = input_buffer_k.index_select(0, new_order) + else: + input_buffer[k] = input_buffer_k.index_select(0, new_order) + incremental_state = self._set_input_buffer(incremental_state, input_buffer) + return incremental_state + + def set_beam_size(self, beam_size): + """Used for effiecient beamable enc-dec attention""" + self.beam_size = beam_size + + def _get_input_buffer( + self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] + ) -> Dict[str, Optional[Tensor]]: + result = self.get_incremental_state(incremental_state, "attn_state") + if result is not None: + return result + else: + empty_result: Dict[str, Optional[Tensor]] = 
{} + return empty_result + + def _set_input_buffer( + self, + incremental_state: Dict[str, Dict[str, Optional[Tensor]]], + buffer: Dict[str, Optional[Tensor]], + ): + return self.set_incremental_state(incremental_state, "attn_state", buffer) + + def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int): + return attn_weights + + def upgrade_state_dict_named(self, state_dict, name): + prefix = name + "." if name != "" else "" + items_to_add = {} + keys_to_remove = [] + for k in state_dict.keys(): + if k.endswith(prefix + "in_proj_weight"): + # in_proj_weight used to be q + k + v with same dimensions + dim = int(state_dict[k].shape[0] / 3) + items_to_add[prefix + "q_proj.weight"] = state_dict[k][:dim] + items_to_add[prefix + "k_proj.weight"] = state_dict[k][dim : 2 * dim] + items_to_add[prefix + "v_proj.weight"] = state_dict[k][2 * dim :] + + keys_to_remove.append(k) + + k_bias = prefix + "in_proj_bias" + if k_bias in state_dict.keys(): + dim = int(state_dict[k].shape[0] / 3) + items_to_add[prefix + "q_proj.bias"] = state_dict[k_bias][:dim] + items_to_add[prefix + "k_proj.bias"] = state_dict[k_bias][ + dim : 2 * dim + ] + items_to_add[prefix + "v_proj.bias"] = state_dict[k_bias][2 * dim :] + + keys_to_remove.append(prefix + "in_proj_bias") + + for k in keys_to_remove: + del state_dict[k] + + for key, value in items_to_add.items(): + state_dict[key] = value + + +class RelPositionalEncoding(nn.Module): + """Relative positional encoding module (new implementation). + + Args: + d_model: Embedding dimension. + dropout_rate: Dropout rate. + max_len: Maximum input length. + """ + + def __init__(self, max_len, d_model): + """Construct an PositionalEncoding object.""" + super(RelPositionalEncoding, self).__init__() + self.d_model = d_model + self.pe = None + self.extend_pe(torch.tensor(0.0).expand(1, max_len)) + + def extend_pe(self, x): + """Reset the positional encodings.""" + if self.pe is not None: + # self.pe contains both positive and negative parts + # the length of self.pe is 2 * input_len - 1 + if self.pe.size(1) >= x.size(1) * 2 - 1: + if self.pe.dtype != x.dtype or self.pe.device != x.device: + self.pe = self.pe.to(dtype=x.dtype, device=x.device) + return + # Suppose `i` means to the position of query vecotr and `j` means the + # position of key vector. We use position relative positions when keys + # are to the left (i>j) and negative relative positions otherwise (i 1 + weight_proj_depth: number of layers (with activation in between) to project input before computing logits + weight_proj_factor: this is used only if weight_proj_depth is > 1. 
scales the inner dimensionality of + projections by this factor + """ + super().__init__() + + self.groups = groups + self.combine_groups = combine_groups + self.input_dim = dim + self.num_vars = num_vars + self.time_first = time_first + + assert ( + vq_dim % groups == 0 + ), f"dim {vq_dim} must be divisible by groups {groups} for concatenation" + + var_dim = vq_dim // groups + num_groups = groups if not combine_groups else 1 + + self.vars = nn.Parameter(torch.FloatTensor(1, num_groups * num_vars, var_dim)) + nn.init.uniform_(self.vars) + + if weight_proj_depth > 1: + + def block(input_dim, output_dim): + return nn.Sequential(nn.Linear(input_dim, output_dim), activation) + + inner_dim = self.input_dim * weight_proj_factor + self.weight_proj = nn.Sequential( + *[ + block(self.input_dim if i == 0 else inner_dim, inner_dim) + for i in range(weight_proj_depth - 1) + ], + nn.Linear(inner_dim, groups * num_vars), + ) + else: + self.weight_proj = nn.Linear(self.input_dim, groups * num_vars) + nn.init.normal_(self.weight_proj.weight, mean=0, std=1) + nn.init.zeros_(self.weight_proj.bias) + + if isinstance(temp, str): + import ast + + temp = ast.literal_eval(temp) + assert len(temp) == 3, f"{temp}, {len(temp)}" + + self.max_temp, self.min_temp, self.temp_decay = temp + self.curr_temp = self.max_temp + self.codebook_indices = None + + def set_num_updates(self, num_updates): + self.curr_temp = max( + self.max_temp * self.temp_decay**num_updates, self.min_temp + ) + + def get_codebook_indices(self): + if self.codebook_indices is None: + from itertools import product + + p = [range(self.num_vars)] * self.groups + inds = list(product(*p)) + self.codebook_indices = torch.tensor( + inds, dtype=torch.long, device=self.vars.device + ).flatten() + + if not self.combine_groups: + self.codebook_indices = self.codebook_indices.view( + self.num_vars**self.groups, -1 + ) + for b in range(1, self.groups): + self.codebook_indices[:, b] += self.num_vars * b + self.codebook_indices = self.codebook_indices.flatten() + return self.codebook_indices + + def codebook(self): + indices = self.get_codebook_indices() + return ( + self.vars.squeeze(0) + .index_select(0, indices) + .view(self.num_vars**self.groups, -1) + ) + + def sample_from_codebook(self, b, n): + indices = self.get_codebook_indices() + indices = indices.view(-1, self.groups) + cb_size = indices.size(0) + assert ( + n < cb_size + ), f"sample size {n} is greater than size of codebook {cb_size}" + sample_idx = torch.randint(low=0, high=cb_size, size=(b * n,)) + indices = indices[sample_idx] + + z = self.vars.squeeze(0).index_select(0, indices.flatten()).view(b, n, -1) + return z + + def to_codebook_index(self, indices): + res = indices.new_full(indices.shape[:-1], 0) + for i in range(self.groups): + exponent = self.groups - i - 1 + res += indices[..., i] * (self.num_vars**exponent) + return res + + def forward_idx(self, x): + res = self.forward(x, produce_targets=True) + return res["x"], res["targets"] + + def forward(self, x, produce_targets=False): + + result = {"num_vars": self.num_vars * self.groups} + + if not self.time_first: + x = x.transpose(1, 2) + + bsz, tsz, fsz = x.shape + x = x.reshape(-1, fsz) + x = self.weight_proj(x) + x = x.view(bsz * tsz * self.groups, -1) + + _, k = x.max(-1) + hard_x = ( + x.new_zeros(*x.shape) + .scatter_(-1, k.view(-1, 1), 1.0) + .view(bsz * tsz, self.groups, -1) + ) + hard_probs = torch.mean(hard_x.float(), dim=0) + result["code_perplexity"] = torch.exp( + -torch.sum(hard_probs * torch.log(hard_probs + 1e-7), dim=-1) + 
).sum() + + avg_probs = torch.softmax( + x.view(bsz * tsz, self.groups, -1).float(), dim=-1 + ).mean(dim=0) + result["prob_perplexity"] = torch.exp( + -torch.sum(avg_probs * torch.log(avg_probs + 1e-7), dim=-1) + ).sum() + + result["temp"] = self.curr_temp + + if self.training: + x = F.gumbel_softmax(x.float(), tau=self.curr_temp, hard=True).type_as(x) + else: + x = hard_x + + x = x.view(bsz * tsz, -1) + + vars = self.vars + if self.combine_groups: + vars = vars.repeat(1, self.groups, 1) + + if produce_targets: + result["targets"] = ( + x.view(bsz * tsz * self.groups, -1) + .argmax(dim=-1) + .view(bsz, tsz, self.groups) + .detach() + ) + + x = x.unsqueeze(-1) * vars + x = x.view(bsz * tsz, self.groups, self.num_vars, -1) + x = x.sum(-2) + x = x.view(bsz, tsz, -1) + + if not self.time_first: + x = x.transpose(1, 2) # BTC -> BCT + + result["x"] = x + + return result + + +class GradMultiply(torch.autograd.Function): + @staticmethod + def forward(ctx, x, scale): + ctx.scale = scale + res = x.new(x) + return res + + @staticmethod + def backward(ctx, grad): + return grad * ctx.scale, None + + +class SamePad(nn.Module): + def __init__(self, kernel_size, causal=False): + super().__init__() + if causal: + self.remove = kernel_size - 1 + else: + self.remove = 1 if kernel_size % 2 == 0 else 0 + + def forward(self, x): + if self.remove > 0: + x = x[:, :, : -self.remove] + return x + + +class TransposeLast(nn.Module): + def __init__(self, deconstruct_idx=None): + super().__init__() + self.deconstruct_idx = deconstruct_idx + + def forward(self, x): + if self.deconstruct_idx is not None: + x = x[self.deconstruct_idx] + return x.transpose(-2, -1) + + +def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True, export=False): + return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine) + + +class Fp32LayerNorm(nn.LayerNorm): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, input): + output = F.layer_norm( + input.float(), + self.normalized_shape, + self.weight.float() if self.weight is not None else None, + self.bias.float() if self.bias is not None else None, + self.eps, + ) + return output.type_as(input) + + +class Fp32GroupNorm(nn.GroupNorm): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, input): + output = F.group_norm( + input.float(), + self.num_groups, + self.weight.float() if self.weight is not None else None, + self.bias.float() if self.bias is not None else None, + self.eps, + ) + return output.type_as(input) + + +class StrEnumMeta(EnumMeta): + # this is workaround for submitit pickling leading to instance checks failing in hydra for StrEnum, see + # https://github.com/facebookresearch/hydra/issues/1156 + @classmethod + def __instancecheck__(cls, other): + return "enum" in str(type(other)) + + +class StrEnum(Enum, metaclass=StrEnumMeta): + def __str__(self): + return self.value + + def __eq__(self, other: str): + return self.value == other + + def __repr__(self): + return self.value + + def __hash__(self): + return hash(str(self)) + + +def ChoiceEnum(choices: List[str]): + """return the Enum class used to enforce list of choices""" + return StrEnum("Choices", {k: k for k in choices}) + + +def relu_squared(x: torch.Tensor): + return F.relu(x).pow(2) + + +def get_activation_fn(activation: str) -> Callable: + """Returns the activation function corresponding to `activation`""" + + def gelu_accurate(x): + if not hasattr(gelu_accurate, "_a"): + gelu_accurate._a = math.sqrt(2 / math.pi) + 
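+ # tanh approximation of GELU from Hendrycks & Gimpel (2016):
+ # 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))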
return ( + 0.5 + * x + * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3)))) + ) + + def gelu(x: torch.Tensor) -> torch.Tensor: + return torch.nn.functional.gelu(x.float()).type_as(x) + + if activation == "relu": + return F.relu + elif activation == "relu_squared": + return relu_squared + elif activation == "gelu": + return gelu + elif activation == "gelu_fast": + return gelu_accurate + elif activation == "gelu_accurate": + return gelu_accurate + elif activation == "tanh": + return torch.tanh + elif activation == "linear": + return lambda x: x + elif activation == "swish": + return torch.nn.SiLU + else: + raise RuntimeError("--activation-fn {} not supported".format(activation)) + + +def get_available_activation_fns() -> List: + return [ + "relu", + "gelu", + "gelu_fast", # deprecated + "gelu_accurate", + "tanh", + "linear", + ] + + +def compute_mask_indices( + shape: Tuple[int, int], + padding_mask: Optional[torch.Tensor], + mask_prob: float, + mask_length: int, + mask_type: str = "static", + mask_other: float = 0.0, + min_masks: int = 0, + no_overlap: bool = False, + min_space: int = 0, + require_same_masks: bool = True, + mask_dropout: float = 0.0, +) -> np.ndarray: + """ + Computes random mask spans for a given shape + + Args: + shape: the the shape for which to compute masks. + should be of size 2 where first element is batch size and 2nd is timesteps + padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements + mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by + number of timesteps divided by length of mask span to mask approximately this percentage of all elements. + however due to overlaps, the actual number will be smaller (unless no_overlap is True) + mask_type: how to compute mask lengths + static = fixed size + uniform = sample from uniform distribution [mask_other, mask_length*2] + normal = sample from normal distribution with mean mask_length and stdev mask_other. 
mask is min 1 element + poisson = sample from possion distribution with lambda = mask length + min_masks: minimum number of masked spans + no_overlap: if false, will switch to an alternative recursive algorithm that prevents spans from overlapping + min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans + require_same_masks: if true, will randomly drop out masks until same amount of masks remains in each sample + mask_dropout: randomly dropout this percentage of masks in each example + """ + + bsz, all_sz = shape + mask = np.full((bsz, all_sz), False) + + all_num_mask = int( + # add a random number for probabilistic rounding + mask_prob * all_sz / float(mask_length) + + np.random.rand() + ) + + all_num_mask = max(min_masks, all_num_mask) + + mask_idcs = [] + for i in range(bsz): + if padding_mask is not None: + sz = all_sz - padding_mask[i].long().sum().item() + num_mask = int( + # add a random number for probabilistic rounding + mask_prob * sz / float(mask_length) + + np.random.rand() + ) + num_mask = max(min_masks, num_mask) + else: + sz = all_sz + num_mask = all_num_mask + + if mask_type == "static": + lengths = np.full(num_mask, mask_length) + elif mask_type == "uniform": + lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask) + elif mask_type == "normal": + lengths = np.random.normal(mask_length, mask_other, size=num_mask) + lengths = [max(1, int(round(x))) for x in lengths] + elif mask_type == "poisson": + lengths = np.random.poisson(mask_length, size=num_mask) + lengths = [int(round(x)) for x in lengths] + else: + raise Exception("unknown mask selection " + mask_type) + + if sum(lengths) == 0: + lengths[0] = min(mask_length, sz - 1) + + if no_overlap: + mask_idc = [] + + def arrange(s, e, length, keep_length): + span_start = np.random.randint(s, e - length) + mask_idc.extend(span_start + i for i in range(length)) + + new_parts = [] + if span_start - s - min_space >= keep_length: + new_parts.append((s, span_start - min_space + 1)) + if e - span_start - length - min_space > keep_length: + new_parts.append((span_start + length + min_space, e)) + return new_parts + + parts = [(0, sz)] + min_length = min(lengths) + for length in sorted(lengths, reverse=True): + lens = np.fromiter( + (e - s if e - s >= length + min_space else 0 for s, e in parts), + np.int, + ) + l_sum = np.sum(lens) + if l_sum == 0: + break + probs = lens / np.sum(lens) + c = np.random.choice(len(parts), p=probs) + s, e = parts.pop(c) + parts.extend(arrange(s, e, length, min_length)) + mask_idc = np.asarray(mask_idc) + else: + min_len = min(lengths) + if sz - min_len <= num_mask: + min_len = sz - num_mask - 1 + + mask_idc = np.random.choice(sz - min_len, num_mask, replace=False) + + mask_idc = np.asarray( + [ + mask_idc[j] + offset + for j in range(len(mask_idc)) + for offset in range(lengths[j]) + ] + ) + + mask_idcs.append(np.unique(mask_idc[mask_idc < sz])) + + min_len = min([len(m) for m in mask_idcs]) + for i, mask_idc in enumerate(mask_idcs): + if len(mask_idc) > min_len and require_same_masks: + mask_idc = np.random.choice(mask_idc, min_len, replace=False) + if mask_dropout > 0: + num_holes = np.rint(len(mask_idc) * mask_dropout).astype(int) + mask_idc = np.random.choice( + mask_idc, len(mask_idc) - num_holes, replace=False + ) + + mask[i, mask_idc] = True + + return mask + + +def index_put(tensor, indices, value): + tensor[indices] = value + return tensor + + +def buffered_arange(max): + if not hasattr(buffered_arange, "buf"): + 
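+ # first call: lazily allocate the shared buffer; it is grown below whenever
+ # a larger range is requested, so repeated calls reuse the same storage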
buffered_arange.buf = torch.LongTensor() + if max > buffered_arange.buf.numel(): + buffered_arange.buf.resize_(max) + torch.arange(max, out=buffered_arange.buf) + return buffered_arange.buf[:max] + + +def pad_to_multiple(x, multiple, dim=-1, value=0): + # Inspired from https://github.com/lucidrains/local-attention/blob/master/local_attention/local_attention.py#L41 + if x is None: + return None, 0 + tsz = x.size(dim) + m = tsz / multiple + remainder = math.ceil(m) * multiple - tsz + if m.is_integer(): + return x, 0 + pad_offset = (0,) * (-1 - dim) * 2 + + return F.pad(x, (*pad_offset, 0, remainder), value=value), remainder + + +EXTRACTOR_MODE_CHOICES = ChoiceEnum(["default", "layer_norm"]) +MASKING_DISTRIBUTION_CHOICES = ChoiceEnum(["static", "uniform", "normal", "poisson"]) +LAYER_TYPE_CHOICES = ChoiceEnum(["transformer", "conformer"]) + + +@dataclass +class Wav2Vec2Config: + extractor_mode: EXTRACTOR_MODE_CHOICES = field( + default="default", + metadata={ + "help": "mode for feature extractor. default has a single group norm with d " + "groups in the first conv block, whereas layer_norm has layer norms in " + "every block (meant to use with normalize=True)" + }, + ) + encoder_layers: int = field( + default=12, metadata={"help": "num encoder layers in the transformer"} + ) + encoder_embed_dim: int = field( + default=768, metadata={"help": "encoder embedding dimension"} + ) + encoder_ffn_embed_dim: int = field( + default=3072, metadata={"help": "encoder embedding dimension for FFN"} + ) + encoder_attention_heads: int = field( + default=12, metadata={"help": "num encoder attention heads"} + ) + activation_fn: ChoiceEnum(get_available_activation_fns()) = field( + default="gelu", metadata={"help": "activation function to use"} + ) + layer_type: LAYER_TYPE_CHOICES = field( + default="transformer", metadata={"help": "layer type in encoder"} + ) + # dropouts + dropout: float = field( + default=0.1, metadata={"help": "dropout probability for the transformer"} + ) + attention_dropout: float = field( + default=0.1, metadata={"help": "dropout probability for attention weights"} + ) + activation_dropout: float = field( + default=0.0, metadata={"help": "dropout probability after activation in FFN"} + ) + encoder_layerdrop: float = field( + default=0.0, metadata={"help": "probability of dropping a tarnsformer layer"} + ) + dropout_input: float = field( + default=0.0, + metadata={"help": "dropout to apply to the input (after feat extr)"}, + ) + dropout_features: float = field( + default=0.0, + metadata={"help": "dropout to apply to the features (after feat extr)"}, + ) + + final_dim: int = field( + default=0, + metadata={ + "help": "project final representations and targets to this many dimensions." 
+ "set to encoder_embed_dim is <= 0" + }, + ) + layer_norm_first: bool = field( + default=False, metadata={"help": "apply layernorm first in the transformer"} + ) + conv_feature_layers: str = field( + default="[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] + [(512,2,2)]", + metadata={ + "help": "string describing convolutional feature extraction layers in form of a python list that contains " + "[(dim, kernel_size, stride), ...]" + }, + ) + conv_bias: bool = field( + default=False, metadata={"help": "include bias in conv encoder"} + ) + logit_temp: float = field( + default=0.1, metadata={"help": "temperature to divide logits by"} + ) + quantize_targets: bool = field( + default=False, metadata={"help": "use quantized targets"} + ) + quantize_input: bool = field( + default=False, metadata={"help": "use quantized inputs"} + ) + same_quantizer: bool = field( + default=False, metadata={"help": "use same quantizer for inputs and targets"} + ) + target_glu: bool = field( + default=False, metadata={"help": "adds projection + glu to targets"} + ) + feature_grad_mult: float = field( + default=1.0, metadata={"help": "multiply feature extractor var grads by this"} + ) + quantizer_depth: int = field( + default=1, + metadata={"help": "number of quantizer layers"}, + ) + quantizer_factor: int = field( + default=3, + metadata={ + "help": "dimensionality increase for inner quantizer layers (if depth > 1)" + }, + ) + latent_vars: int = field( + default=320, + metadata={"help": "number of latent variables V in each group of the codebook"}, + ) + latent_groups: int = field( + default=2, + metadata={"help": "number of groups G of latent variables in the codebook"}, + ) + latent_dim: int = field( + default=0, + metadata={ + "help": "if > 0, uses this dimensionality for latent variables. 
" + "otherwise uses final_dim / latent_groups" + }, + ) + + # masking + mask_length: int = field(default=10, metadata={"help": "mask length"}) + mask_prob: float = field( + default=0.65, metadata={"help": "probability of replacing a token with mask"} + ) + mask_selection: MASKING_DISTRIBUTION_CHOICES = field( + default="static", metadata={"help": "how to choose mask length"} + ) + mask_other: float = field( + default=0, + metadata={ + "help": "secondary mask argument (used for more complex distributions), " + "see help in compute_mask_indices" + }, + ) + no_mask_overlap: bool = field( + default=False, metadata={"help": "whether to allow masks to overlap"} + ) + mask_min_space: int = field( + default=1, + metadata={"help": "min space between spans (if no overlap is enabled)"}, + ) + require_same_masks: bool = field( + default=True, + metadata={ + "help": "whether to number of masked timesteps must be the same across all " + "examples in a batch" + }, + ) + mask_dropout: float = field( + default=0.0, + metadata={"help": "percent of masks to unmask for each sample"}, + ) + + # channel masking + mask_channel_length: int = field( + default=10, metadata={"help": "length of the mask for features (channels)"} + ) + mask_channel_prob: float = field( + default=0.0, metadata={"help": "probability of replacing a feature with 0"} + ) + mask_channel_before: bool = False + mask_channel_selection: MASKING_DISTRIBUTION_CHOICES = field( + default="static", + metadata={"help": "how to choose mask length for channel masking"}, + ) + mask_channel_other: float = field( + default=0, + metadata={ + "help": "secondary mask argument (used for more complex distributions), " + "see help in compute_mask_indicesh" + }, + ) + no_mask_channel_overlap: bool = field( + default=False, metadata={"help": "whether to allow channel masks to overlap"} + ) + mask_channel_min_space: int = field( + default=1, + metadata={"help": "min space between spans (if no overlap is enabled)"}, + ) + + # negative selection + num_negatives: int = field( + default=100, + metadata={"help": "number of negative examples from the same sample"}, + ) + negatives_from_everywhere: bool = field( + default=False, + metadata={"help": "sample negatives from everywhere, not just masked states"}, + ) + cross_sample_negatives: int = field( + default=0, metadata={"help": "number of negative examples from the any sample"} + ) + codebook_negatives: int = field( + default=0, metadata={"help": "number of negative examples codebook"} + ) + + # positional embeddings + conv_pos: int = field( + default=128, + metadata={"help": "number of filters for convolutional positional embeddings"}, + ) + conv_pos_groups: int = field( + default=16, + metadata={"help": "number of groups for convolutional positional embedding"}, + ) + pos_conv_depth: int = field( + default=1, + metadata={"help": "depth of positional encoder network"}, + ) + + latent_temp: Tuple[float, float, float] = field( + default=(2, 0.5, 0.999995), + metadata={ + "help": "temperature for latent variable sampling. 
" + "can be tuple of 3 values (start, end, decay)" + }, + ) + max_positions: int = field(default=100000, metadata={"help": "Max positions"}) + checkpoint_activations: bool = field( + default=False, + metadata={"help": "recompute activations and save memory for extra compute"}, + ) + + # FP16 optimization + required_seq_len_multiple: int = field( + default=2, + metadata={ + "help": "pad the input to encoder such that the sequence length is divisible by multiple" + }, + ) + crop_seq_to_multiple: int = field( + default=1, + metadata={ + "help": "crop convolutional feature extractor output such that the sequence length is divisible by multiple" + }, + ) + + # Conformer + depthwise_conv_kernel_size: int = field( + default=31, + metadata={ + "help": "depthwise-conv-kernel-size for convolution in conformer layer" + }, + ) + attn_type: str = field( + default="", + metadata={"help": "if espnet use ESPNET MHA"}, + ) + pos_enc_type: str = field( + default="abs", + metadata={"help": "Positional encoding type to use in conformer"}, + ) + fp16: bool = field(default=False, metadata={"help": "If fp16 is being used"}) + + +class Wav2Vec2Model(nn.Module): + def __init__(self, cfg: Wav2Vec2Config): + super().__init__() + self.cfg = cfg + + feature_enc_layers = eval(cfg.conv_feature_layers) + self.embed = feature_enc_layers[-1][0] + + self.feature_extractor = ConvFeatureExtractionModel( + conv_layers=feature_enc_layers, + dropout=0.0, + mode=cfg.extractor_mode, + conv_bias=cfg.conv_bias, + ) + + self.post_extract_proj = ( + nn.Linear(self.embed, cfg.encoder_embed_dim) + if self.embed != cfg.encoder_embed_dim and not cfg.quantize_input + else None + ) + + self.crop_seq_to_multiple = cfg.crop_seq_to_multiple + + self.mask_prob = cfg.mask_prob + self.mask_selection = cfg.mask_selection + self.mask_other = cfg.mask_other + self.mask_length = cfg.mask_length + self.no_mask_overlap = cfg.no_mask_overlap + self.mask_min_space = cfg.mask_min_space + + self.mask_channel_prob = cfg.mask_channel_prob + self.mask_channel_before = cfg.mask_channel_before + self.mask_channel_selection = cfg.mask_channel_selection + self.mask_channel_other = cfg.mask_channel_other + self.mask_channel_length = cfg.mask_channel_length + self.no_mask_channel_overlap = cfg.no_mask_channel_overlap + self.mask_channel_min_space = cfg.mask_channel_min_space + + self.dropout_input = nn.Dropout(cfg.dropout_input) + self.dropout_features = nn.Dropout(cfg.dropout_features) + + self.feature_grad_mult = cfg.feature_grad_mult + + self.quantizer = None + self.input_quantizer = None + + self.n_negatives = cfg.num_negatives + self.cross_sample_negatives = cfg.cross_sample_negatives + self.codebook_negatives = cfg.codebook_negatives + self.negatives_from_everywhere = cfg.negatives_from_everywhere + + self.logit_temp = cfg.logit_temp + + final_dim = cfg.final_dim if cfg.final_dim > 0 else cfg.encoder_embed_dim + + if cfg.quantize_targets: + vq_dim = cfg.latent_dim if cfg.latent_dim > 0 else final_dim + self.quantizer = GumbelVectorQuantizer( + dim=self.embed, + num_vars=cfg.latent_vars, + temp=cfg.latent_temp, + groups=cfg.latent_groups, + combine_groups=False, + vq_dim=vq_dim, + time_first=True, + weight_proj_depth=cfg.quantizer_depth, + weight_proj_factor=cfg.quantizer_factor, + ) + self.project_q = nn.Linear(vq_dim, final_dim) + else: + self.project_q = nn.Linear(self.embed, final_dim) + + if cfg.quantize_input: + if cfg.same_quantizer and self.quantizer is not None: + vq_dim = final_dim + self.input_quantizer = self.quantizer + else: + vq_dim = 
cfg.latent_dim if cfg.latent_dim > 0 else cfg.encoder_embed_dim + self.input_quantizer = GumbelVectorQuantizer( + dim=self.embed, + num_vars=cfg.latent_vars, + temp=cfg.latent_temp, + groups=cfg.latent_groups, + combine_groups=False, + vq_dim=vq_dim, + time_first=True, + weight_proj_depth=cfg.quantizer_depth, + weight_proj_factor=cfg.quantizer_factor, + ) + self.project_inp = nn.Linear(vq_dim, cfg.encoder_embed_dim) + + self.mask_emb = nn.Parameter( + torch.FloatTensor(cfg.encoder_embed_dim).uniform_() + ) + encoder_cls = TransformerEncoder + if cfg.layer_type == "conformer" and cfg.pos_enc_type in ["rel_pos", "rope"]: + encoder_cls = ConformerEncoder + + self.encoder = encoder_cls(cfg) + self.layer_norm = LayerNorm(self.embed) + + self.target_glu = None + if cfg.target_glu: + self.target_glu = nn.Sequential( + nn.Linear(final_dim, final_dim * 2), nn.GLU() + ) + + self.final_proj = nn.Linear(cfg.encoder_embed_dim, final_dim) + + def upgrade_state_dict_named(self, state_dict, name): + super().upgrade_state_dict_named(state_dict, name) + """Upgrade a (possibly old) state dict for new versions of fairseq.""" + return state_dict + + @classmethod + def build_model(cls, cfg: Wav2Vec2Config, task=None): + """Build a new model instance.""" + return cls(cfg) + + def apply_mask( + self, + x, + padding_mask, + mask_indices=None, + mask_channel_indices=None, + ): + B, T, C = x.shape + + if self.mask_channel_prob > 0 and self.mask_channel_before: + mask_channel_indices = compute_mask_indices( + (B, C), + None, + self.mask_channel_prob, + self.mask_channel_length, + self.mask_channel_selection, + self.mask_channel_other, + no_overlap=self.no_mask_channel_overlap, + min_space=self.mask_channel_min_space, + ) + mask_channel_indices = ( + torch.from_numpy(mask_channel_indices) + .to(x.device) + .unsqueeze(1) + .expand(-1, T, -1) + ) + x[mask_channel_indices] = 0 + + if self.mask_prob > 0: + if mask_indices is None: + mask_indices = compute_mask_indices( + (B, T), + padding_mask, + self.mask_prob, + self.mask_length, + self.mask_selection, + self.mask_other, + min_masks=2, + no_overlap=self.no_mask_overlap, + min_space=self.mask_min_space, + require_same_masks=self.cfg.require_same_masks, + mask_dropout=self.cfg.mask_dropout, + ) + mask_indices = torch.from_numpy(mask_indices).to(x.device) + x = index_put(x, mask_indices, self.mask_emb) + else: + mask_indices = None + + if self.mask_channel_prob > 0 and not self.mask_channel_before: + if mask_channel_indices is None: + mask_channel_indices = compute_mask_indices( + (B, C), + None, + self.mask_channel_prob, + self.mask_channel_length, + self.mask_channel_selection, + self.mask_channel_other, + no_overlap=self.no_mask_channel_overlap, + min_space=self.mask_channel_min_space, + ) + mask_channel_indices = ( + torch.from_numpy(mask_channel_indices) + .to(x.device) + .unsqueeze(1) + .expand(-1, T, -1) + ) + x = index_put(x, mask_channel_indices, 0) + + return x, mask_indices + + def sample_negatives(self, y, num, padding_count=None): + + if self.n_negatives == 0 and self.cross_sample_negatives == 0: + return y.new(0) + + bsz, tsz, fsz = y.shape + y = y.view(-1, fsz) # BTC => (BxT)C + + # FIXME: what happens if padding_count is specified? 
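`apply_mask` above delegates span selection to fairseq's `compute_mask_indices` and then overwrites the selected frames with the learned `self.mask_emb` vector via `index_put`. A much-simplified sketch of the span-selection idea (static span length only; the names are illustrative, and the real helper additionally handles padding, overlap control, `min_space`, and per-batch consistency):

```python
# Simplified stand-in for compute_mask_indices (assumption: "static" distribution only).
import numpy as np
import torch

def simple_span_mask(batch: int, frames: int, mask_prob: float, mask_length: int) -> torch.Tensor:
    mask = np.zeros((batch, frames), dtype=bool)
    # pick enough spans that roughly mask_prob of all frames end up masked
    num_spans = int(mask_prob * frames / float(mask_length) + np.random.rand())
    for b in range(batch):
        starts = np.random.choice(frames - mask_length, num_spans, replace=False)
        for s in starts:
            mask[b, s : s + mask_length] = True  # mark a contiguous span of mask_length frames
    return torch.from_numpy(mask)

mask = simple_span_mask(batch=2, frames=200, mask_prob=0.65, mask_length=10)
print(mask.float().mean())  # roughly 0.65 (overlapping spans can push it slightly lower)
```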
+ cross_high = tsz * bsz + high = tsz - (padding_count or 0) + with torch.no_grad(): + assert high > 1, f"{bsz,tsz,fsz}" + + if self.n_negatives > 0: + tszs = ( + buffered_arange(num) + .unsqueeze(-1) + .expand(-1, self.n_negatives) + .flatten() + ) + + neg_idxs = torch.randint( + low=0, high=high - 1, size=(bsz, self.n_negatives * num) + ) + neg_idxs[neg_idxs >= tszs] += 1 + + if self.cross_sample_negatives > 0: + tszs = ( + buffered_arange(num) + .unsqueeze(-1) + .expand(-1, self.cross_sample_negatives) + .flatten() + ) + + cross_neg_idxs = torch.randint( + low=0, + high=cross_high - 1, + size=(bsz, self.cross_sample_negatives * num), + ) + cross_neg_idxs[cross_neg_idxs >= tszs] += 1 + + if self.n_negatives > 0: + neg_idxs = neg_idxs + (torch.arange(bsz).unsqueeze(1) * high) + else: + neg_idxs = cross_neg_idxs + + if self.cross_sample_negatives > 0 and self.n_negatives > 0: + neg_idxs = torch.cat([neg_idxs, cross_neg_idxs], dim=1) + + negs = y[neg_idxs.view(-1)] + negs = negs.view( + bsz, num, self.n_negatives + self.cross_sample_negatives, fsz + ).permute( + 2, 0, 1, 3 + ) # to NxBxTxC + return negs, neg_idxs + + def compute_preds(self, x, y, negatives): + + neg_is_pos = (y == negatives).all(-1) + y = y.unsqueeze(0) + targets = torch.cat([y, negatives], dim=0) + + logits = torch.cosine_similarity(x.float(), targets.float(), dim=-1) + logits = logits / self.logit_temp + logits = logits.type_as(x) + + return logits + + def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor): + """ + Computes the output length of the convolutional layers + """ + + def _conv_out_length(input_length, kernel_size, stride): + return torch.floor((input_length - kernel_size) / stride + 1) + + conv_cfg_list = eval(self.cfg.conv_feature_layers) + + for i in range(len(conv_cfg_list)): + input_lengths = _conv_out_length( + input_lengths, conv_cfg_list[i][1], conv_cfg_list[i][2] + ) + + return input_lengths.to(torch.long) + + def forward( + self, + source, + padding_mask=None, + mask=True, + features_only=False, + layer=None, + mask_indices=None, + mask_channel_indices=None, + padding_count=None, + ): + + if self.feature_grad_mult > 0: + features = self.feature_extractor(source) + if self.feature_grad_mult != 1.0: + features = GradMultiply.apply(features, self.feature_grad_mult) + else: + with torch.no_grad(): + features = self.feature_extractor(source) + + features_pen = features.float().pow(2).mean() + + features = features.transpose(1, 2) + features = self.layer_norm(features) + unmasked_features = features.clone() + + if padding_mask is not None and padding_mask.any(): + input_lengths = (1 - padding_mask.long()).sum(-1) + # apply conv formula to get real output_lengths + output_lengths = self._get_feat_extract_output_lengths(input_lengths) + + padding_mask = torch.zeros( + features.shape[:2], dtype=features.dtype, device=features.device + ) + + # these two operations makes sure that all values + # before the output lengths indices are attended to + padding_mask[ + ( + torch.arange(padding_mask.shape[0], device=padding_mask.device), + output_lengths - 1, + ) + ] = 1 + padding_mask = (1 - padding_mask.flip([-1]).cumsum(-1).flip([-1])).bool() + else: + padding_mask = None + + time_steps_to_drop = features.size(1) % self.crop_seq_to_multiple + if time_steps_to_drop != 0: + features = features[:, :-time_steps_to_drop] + unmasked_features = unmasked_features[:, :-time_steps_to_drop] + if padding_mask is not None: + padding_mask = padding_mask[:, :-time_steps_to_drop] + + if self.post_extract_proj is 
not None: + features = self.post_extract_proj(features) + + features = self.dropout_input(features) + unmasked_features = self.dropout_features(unmasked_features) + + num_vars = None + code_ppl = None + prob_ppl = None + curr_temp = None + + if self.input_quantizer: + q = self.input_quantizer(features, produce_targets=False) + features = q["x"] + num_vars = q["num_vars"] + code_ppl = q["code_perplexity"] + prob_ppl = q["prob_perplexity"] + curr_temp = q["temp"] + features = self.project_inp(features) + + if mask: + x, mask_indices = self.apply_mask( + features, + padding_mask, + mask_indices=mask_indices, + mask_channel_indices=mask_channel_indices, + ) + if mask_indices is not None: + y = unmasked_features[mask_indices].view( + unmasked_features.size(0), -1, unmasked_features.size(-1) + ) + else: + x = features + y = unmasked_features + mask_indices = None + + x, layer_results = self.encoder(x, padding_mask=padding_mask, layer=layer) + + if features_only: + return { + "x": x, + "padding_mask": padding_mask, + "features": unmasked_features, + "layer_results": layer_results, + } + + if self.quantizer: + if self.negatives_from_everywhere: + q = self.quantizer(unmasked_features, produce_targets=False) + y = q["x"] + num_vars = q["num_vars"] + code_ppl = q["code_perplexity"] + prob_ppl = q["prob_perplexity"] + curr_temp = q["temp"] + y = self.project_q(y) + + negs, _ = self.sample_negatives( + y, + mask_indices[0].sum(), + padding_count=padding_count, + ) + y = y[mask_indices].view(y.size(0), -1, y.size(-1)) + + else: + q = self.quantizer(y, produce_targets=False) + y = q["x"] + num_vars = q["num_vars"] + code_ppl = q["code_perplexity"] + prob_ppl = q["prob_perplexity"] + curr_temp = q["temp"] + + y = self.project_q(y) + + negs, _ = self.sample_negatives( + y, + y.size(1), + padding_count=padding_count, + ) + + if self.codebook_negatives > 0: + cb_negs = self.quantizer.sample_from_codebook( + y.size(0) * y.size(1), self.codebook_negatives + ) + cb_negs = cb_negs.view( + self.codebook_negatives, y.size(0), y.size(1), -1 + ) # order doesnt matter + cb_negs = self.project_q(cb_negs) + negs = torch.cat([negs, cb_negs], dim=0) + else: + y = self.project_q(y) + + if self.negatives_from_everywhere: + negs, _ = self.sample_negatives( + unmasked_features, + y.size(1), + padding_count=padding_count, + ) + negs = self.project_q(negs) + else: + negs, _ = self.sample_negatives( + y, + y.size(1), + padding_count=padding_count, + ) + + x = x[mask_indices].view(x.size(0), -1, x.size(-1)) + + if self.target_glu: + y = self.target_glu(y) + negs = self.target_glu(negs) + + x = self.final_proj(x) + x = self.compute_preds(x, y, negs) + + result = { + "x": x, + "padding_mask": padding_mask, + "features_pen": features_pen, + } + + if prob_ppl is not None: + result["prob_perplexity"] = prob_ppl + result["code_perplexity"] = code_ppl + result["num_vars"] = num_vars + result["temp"] = curr_temp + + return result + + def quantize(self, x): + assert self.quantizer is not None + x = self.feature_extractor(x) + x = x.transpose(1, 2) + x = self.layer_norm(x) + return self.quantizer.forward_idx(x) + + def extract_features(self, source, padding_mask, mask=False, layer=None): + res = self.forward( + source, padding_mask, mask=mask, features_only=True, layer=layer + ) + return res + + def get_logits(self, net_output): + logits = net_output["x"] + logits = logits.transpose(0, 2) + logits = logits.reshape(-1, logits.size(-1)) + return logits + + def get_targets(self, sample, net_output, expand_steps=True): + x = net_output["x"] + 
return x.new_zeros(x.size(1) * x.size(2), dtype=torch.long) + + def get_extra_losses(self, net_output): + pen = [] + + if "prob_perplexity" in net_output: + pen.append( + (net_output["num_vars"] - net_output["prob_perplexity"]) + / net_output["num_vars"] + ) + + if "features_pen" in net_output: + pen.append(net_output["features_pen"]) + + return pen + + def remove_pretraining_modules(self, last_layer=None): + self.quantizer = None + self.project_q = None + self.target_glu = None + self.final_proj = None + + if last_layer is not None: + self.encoder.layers = nn.ModuleList( + l for i, l in enumerate(self.encoder.layers) if i <= last_layer + ) + + +class ConvFeatureExtractionModel(nn.Module): + def __init__( + self, + conv_layers: List[Tuple[int, int, int]], + dropout: float = 0.0, + mode: str = "default", + conv_bias: bool = False, + ): + super().__init__() + + assert mode in {"default", "layer_norm"} + + def block( + n_in, + n_out, + k, + stride, + is_layer_norm=False, + is_group_norm=False, + conv_bias=False, + ): + def make_conv(): + conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias) + nn.init.kaiming_normal_(conv.weight) + return conv + + assert ( + is_layer_norm and is_group_norm + ) == False, "layer norm and group norm are exclusive" + + if is_layer_norm: + return nn.Sequential( + make_conv(), + nn.Dropout(p=dropout), + nn.Sequential( + TransposeLast(), + Fp32LayerNorm(dim, elementwise_affine=True), + TransposeLast(), + ), + nn.GELU(), + ) + elif is_group_norm: + return nn.Sequential( + make_conv(), + nn.Dropout(p=dropout), + Fp32GroupNorm(dim, dim, affine=True), + nn.GELU(), + ) + else: + return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU()) + + in_d = 1 + self.conv_layers = nn.ModuleList() + for i, cl in enumerate(conv_layers): + assert len(cl) == 3, "invalid conv definition: " + str(cl) + (dim, k, stride) = cl + + self.conv_layers.append( + block( + in_d, + dim, + k, + stride, + is_layer_norm=mode == "layer_norm", + is_group_norm=mode == "default" and i == 0, + conv_bias=conv_bias, + ) + ) + in_d = dim + + def forward(self, x): + + # BxT -> BxCxT + x = x.unsqueeze(1) + + for conv in self.conv_layers: + x = conv(x) + + return x + + +def make_conv_pos(e, k, g, is_batch_norm=False): + pos_conv = nn.Conv1d( + e, + e, + kernel_size=k, + padding=k // 2, + groups=g, + ) + dropout = 0 + std = math.sqrt((4 * (1.0 - dropout)) / (k * e)) + nn.init.normal_(pos_conv.weight, mean=0, std=std) + nn.init.constant_(pos_conv.bias, 0) + + if not is_batch_norm: + pos_conv = nn.utils.weight_norm(pos_conv, name="weight", dim=2) + pos_conv = nn.Sequential(pos_conv, SamePad(k), nn.GELU()) + else: + batch_norm = nn.BatchNorm1d(e) + pos_conv = nn.Sequential(batch_norm, pos_conv, SamePad(k), nn.GELU()) + + return pos_conv + + +class TransformerEncoder(nn.Module): + def build_encoder_layer(self, args: Wav2Vec2Config): + if args.layer_type == "transformer": + layer = TransformerSentenceEncoderLayer( + embedding_dim=self.embedding_dim, + ffn_embedding_dim=args.encoder_ffn_embed_dim, + num_attention_heads=args.encoder_attention_heads, + dropout=self.dropout, + attention_dropout=args.attention_dropout, + activation_dropout=args.activation_dropout, + activation_fn=args.activation_fn, + layer_norm_first=args.layer_norm_first, + ) + elif args.layer_type == "conformer": + layer = ConformerWav2Vec2EncoderLayer( + embed_dim=self.embedding_dim, + ffn_embed_dim=args.encoder_ffn_embed_dim, + attention_heads=args.encoder_attention_heads, + dropout=args.dropout, + 
depthwise_conv_kernel_size=args.depthwise_conv_kernel_size, + activation_fn="swish", + attn_type=args.attn_type, + use_fp16=args.fp16, + pos_enc_type="abs", + ) + return layer + + def __init__(self, args: Wav2Vec2Config): + super().__init__() + + self.dropout = args.dropout + self.embedding_dim = args.encoder_embed_dim + self.required_seq_len_multiple = args.required_seq_len_multiple + + pos_conv_depth = getattr(args, "pos_conv_depth", 1) + if pos_conv_depth > 1: + num_layers = args.pos_conv_depth + k = max(3, args.conv_pos // num_layers) + + def make_conv_block(e, k, g, l): + return nn.Sequential( + *[ + nn.Sequential( + nn.Conv1d( + e, + e, + kernel_size=k, + padding=k // 2, + groups=g, + ), + SamePad(k), + TransposeLast(), + LayerNorm(e, elementwise_affine=False), + TransposeLast(), + nn.GELU(), + ) + for _ in range(l) + ] + ) + + self.pos_conv = make_conv_block( + self.embedding_dim, k, args.conv_pos_groups, num_layers + ) + + else: + self.pos_conv = make_conv_pos( + self.embedding_dim, + args.conv_pos, + args.conv_pos_groups, + is_batch_norm=( + args.conv_pos_batch_norm + if hasattr(args, "conv_pos_batch_norm") + else False + ), + ) + + self.layers = nn.ModuleList( + [self.build_encoder_layer(args) for _ in range(args.encoder_layers)] + ) + self.layer_norm_first = args.layer_norm_first + self.layer_norm = LayerNorm(self.embedding_dim) + self.layerdrop = args.encoder_layerdrop + + def forward(self, x, padding_mask=None, layer=None): + x, layer_results = self.extract_features(x, padding_mask, layer) + + if self.layer_norm_first and layer is None: + x = self.layer_norm(x) + + return x, layer_results + + def extract_features( + self, + x, + padding_mask=None, + tgt_layer=None, + min_layer=0, + ): + + if padding_mask is not None: + x = index_put(x, padding_mask, 0) + + x_conv = self.pos_conv(x.transpose(1, 2)) + x_conv = x_conv.transpose(1, 2) + x = x + x_conv + + if not self.layer_norm_first: + x = self.layer_norm(x) + + # pad to the sequence length dimension + x, pad_length = pad_to_multiple( + x, self.required_seq_len_multiple, dim=-2, value=0 + ) + if pad_length > 0 and padding_mask is None: + padding_mask = x.new_zeros((x.size(0), x.size(1)), dtype=torch.bool) + padding_mask[:, -pad_length:] = True + else: + padding_mask, _ = pad_to_multiple( + padding_mask, self.required_seq_len_multiple, dim=-1, value=True + ) + x = F.dropout(x, p=self.dropout, training=self.training) + + # B x T x C -> T x B x C + x = x.transpose(0, 1) + + layer_results = [] + r = None + for i, layer in enumerate(self.layers): + dropout_probability = np.random.random() if self.layerdrop > 0 else 1 + if not self.training or (dropout_probability > self.layerdrop): + x, (z, lr) = layer( + x, self_attn_padding_mask=padding_mask, need_weights=False + ) + if i >= min_layer: + layer_results.append((x, z, lr)) + if i == tgt_layer: + r = x + break + + if r is not None: + x = r + + # T x B x C -> B x T x C + x = x.transpose(0, 1) + + # undo paddding + if pad_length > 0: + x = x[:, :-pad_length] + + def undo_pad(a, b, c): + return ( + a[:-pad_length], + b[:-pad_length] if b is not None else b, + c[:-pad_length], + ) + + layer_results = [undo_pad(*u) for u in layer_results] + + return x, layer_results + + def max_positions(self): + """Maximum output length supported by the encoder.""" + return self.args.max_positions + + def upgrade_state_dict_named(self, state_dict, name): + """Upgrade a (possibly old) state dict for new versions of fairseq.""" + return state_dict + + +class ConformerEncoder(TransformerEncoder): + def 
build_encoder_layer(self, args): + layer = ConformerWav2Vec2EncoderLayer( + embed_dim=self.embedding_dim, + ffn_embed_dim=args.encoder_ffn_embed_dim, + attention_heads=args.encoder_attention_heads, + dropout=args.dropout, + depthwise_conv_kernel_size=args.depthwise_conv_kernel_size, + activation_fn="swish", + attn_type=args.attn_type, + pos_enc_type=args.pos_enc_type, + use_fp16=args.fp16, # only used for rope + ) + return layer + + def __init__(self, args): + super().__init__(args) + self.args = args + self.dropout = args.dropout + self.embedding_dim = args.encoder_embed_dim + self.pos_enc_type = args.pos_enc_type + max_source_positions = self.max_positions() + + if self.pos_enc_type == "rel_pos": + self.embed_positions = RelPositionalEncoding( + max_source_positions, self.embedding_dim + ) + elif self.pos_enc_type == "rope": + self.embed_positions = None + else: + raise Exception("Unsupported positional encoding type") + + self.layers = nn.ModuleList( + [self.build_encoder_layer(args) for _ in range(args.encoder_layers)] + ) + self.layer_norm_first = args.layer_norm_first + self.layer_norm = LayerNorm(self.embedding_dim) + self.layerdrop = args.encoder_layerdrop + + def extract_features(self, x, padding_mask=None, tgt_layer=None): + if padding_mask is not None: + x = index_put(x, padding_mask, 0) + + # B x T x C -> T x B x C + x = x.transpose(0, 1) + + # B X T X C here + position_emb = None + if self.pos_enc_type == "rel_pos": + position_emb = self.embed_positions(x) + + if not self.layer_norm_first: + x = self.layer_norm(x) + + x = F.dropout(x, p=self.dropout, training=self.training) + + layer_results = [] + r = None + for i, layer in enumerate(self.layers): + dropout_probability = np.random.random() + if not self.training or (dropout_probability > self.layerdrop): + x, z = layer( + x, + self_attn_padding_mask=padding_mask, + need_weights=False, + position_emb=position_emb, + ) + if tgt_layer is not None: + layer_results.append((x, z)) + if i == tgt_layer: + r = x + break + + if r is not None: + x = r + + # T x B x C -> B x T x C + x = x.transpose(0, 1) + + return x, layer_results + + +class TransformerSentenceEncoderLayer(nn.Module): + """ + Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained + models. 
+ """ + + def __init__( + self, + embedding_dim: float = 768, + ffn_embedding_dim: float = 3072, + num_attention_heads: int = 8, + dropout: float = 0.1, + attention_dropout: float = 0.1, + activation_dropout: float = 0.1, + activation_fn: str = "relu", + layer_norm_first: bool = False, + ) -> None: + + super().__init__() + # Initialize parameters + self.embedding_dim = embedding_dim + self.dropout = dropout + self.activation_dropout = activation_dropout + + # Initialize blocks + self.activation_fn = get_activation_fn(activation_fn) + self.self_attn = MultiheadAttention( + self.embedding_dim, + num_attention_heads, + dropout=attention_dropout, + self_attention=True, + ) + + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(self.activation_dropout) + self.dropout3 = nn.Dropout(dropout) + + self.layer_norm_first = layer_norm_first + + # layer norm associated with the self attention layer + self.self_attn_layer_norm = LayerNorm(self.embedding_dim) + self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim) + self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim) + + # layer norm associated with the position wise feed-forward NN + self.final_layer_norm = LayerNorm(self.embedding_dim) + + def forward( + self, + x: torch.Tensor, + self_attn_mask: torch.Tensor = None, + self_attn_padding_mask: torch.Tensor = None, + need_weights: bool = False, + att_args=None, + ): + """ + LayerNorm is applied either before or after the self-attention/ffn + modules similar to the original Transformer imlementation. + """ + residual = x + + if self.layer_norm_first: + x = self.self_attn_layer_norm(x) + x, attn = self.self_attn( + query=x, + key=x, + value=x, + key_padding_mask=self_attn_padding_mask, + attn_mask=self_attn_mask, + need_weights=False, + ) + x = self.dropout1(x) + x = residual + x + + residual = x + x = self.final_layer_norm(x) + x = self.activation_fn(self.fc1(x)) + x = self.dropout2(x) + x = self.fc2(x) + + layer_result = x + + x = self.dropout3(x) + x = residual + x + else: + x, attn = self.self_attn( + query=x, + key=x, + value=x, + key_padding_mask=self_attn_padding_mask, + need_weights=False, + ) + + x = self.dropout1(x) + x = residual + x + + x = self.self_attn_layer_norm(x) + + residual = x + x = self.activation_fn(self.fc1(x)) + x = self.dropout2(x) + x = self.fc2(x) + + layer_result = x + + x = self.dropout3(x) + x = residual + x + x = self.final_layer_norm(x) + + return x, (attn, layer_result) + + +@dataclass +class AudioPretrainingConfig: + sample_rate: int = field( + default=16_000, + metadata={ + "help": "target sample rate. audio files will be up/down sampled to this rate" + }, + ) + normalize: bool = field( + default=False, + metadata={"help": "if set, normalizes input to have 0 mean and unit variance"}, + ) + enable_padding: bool = field( + default=False, metadata={"help": "pad shorter samples instead of cropping"} + ) + max_sample_size: Optional[int] = field( + default=None, metadata={"help": "max sample size to crop to for batching"} + ) + min_sample_size: Optional[int] = field( + default=None, metadata={"help": "min sample size to skip small examples"} + ) diff --git a/spiritlm/speech_tokenizer/hubert/hubert_tokenizer.py b/spiritlm/speech_tokenizer/hubert/hubert_tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..cc3d86070059a20134b795d59d1d69e4232ad4af --- /dev/null +++ b/spiritlm/speech_tokenizer/hubert/hubert_tokenizer.py @@ -0,0 +1,147 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the FAIR Noncommercial Research License +# found in the LICENSE file in the root directory of this source tree. + +import torch +import torchaudio +from torch import nn + +from .hubert_model import load_hubert_model +from .quantizer_model import load_quantizer_model + + +class HubertTokenizer(nn.Module): + def __init__( + self, + hubert_ckpt, + hubert_layer, + quantizer_ckpt, + is_linear_quantizer=True, + min_chunk=400, + max_chunk=100 * 16_000, + ): + super().__init__() + + # hubert model + self.hubert_ckpt = str(hubert_ckpt) + self.hubert_layer = hubert_layer + self.hubert_model = None + self.should_normalize = False + self.min_chunk = min_chunk + self.max_chunk = max_chunk + + # quantizer model + self.quantizer_ckpt = str(quantizer_ckpt) + self.is_linear_quantizer = is_linear_quantizer + self.quantizer_model = None + + # this is useful for determining the device + self.register_buffer("_float_tensor", torch.tensor([0], dtype=torch.float)) + self.load_models() + + @torch.no_grad() # otherwise some non-leaf nodes appear which breaks serialization + def load_models(self): + # Load hubert model + hubert_model, model_cfg, task_cfg = load_hubert_model(self.hubert_ckpt) + self.hubert_task_cfg = task_cfg + self.hubert_model_cfg = model_cfg + self.hubert_model = hubert_model + self.hubert_model.to(self.device) + self.hubert_model.eval() + for parameter in self.hubert_model.parameters(): + parameter.requires_grad_(False) + self.should_normalize = task_cfg.normalize + + # Load quantizer model + self.quantizer_model = load_quantizer_model( + self.quantizer_ckpt, is_linear_quantizer=self.is_linear_quantizer + ) + self.quantizer_model.to(self.device) + self.quantizer_model.eval() + + @property + def device(self): + return self._float_tensor.device + + @property + def code_hop_size(self) -> int: + hop_size = 1 + for dim, kernel, stride in eval(self.hubert_model_cfg.conv_feature_layers): + hop_size *= stride + return hop_size # 320 for 50hz model and 640 for 25hz model + + @property + def frame_rate(self) -> int: + return self.expected_sample_rate / self.code_hop_size # 50 or 25 + + @property + def n_units(self) -> int: + return self.kmeans_model.K + + @property + def expected_sample_rate(self) -> int: + return self.hubert_task_cfg.sample_rate # 16_000 + + def load_audio(self, path): + wav, sr = torchaudio.load(path) + if sr != self.expected_sample_rate: + wav = torchaudio.functional.resample( + wav, orig_freq=sr, new_freq=self.expected_sample_rate + ) + return wav + + @torch.inference_mode() + def forward(self, x, separate_channels=False, dense=False): + if isinstance(x, str): + x = self.load_audio(x) + i_ndim = x.dim() + if i_ndim == 2: + x = x.unsqueeze(0) + elif i_ndim == 1: + x = x.view(1, 1, -1) + + # x should expect a shape [B, C, T], where C is number of channels + assert len(x.shape) == 3 + feats = self.get_dense_features(x) # [B, T_enc] + + if dense: + return feats + + tokens = self.quantizer_model(feats) # [B, T_enc] + + if i_ndim == 3: + tokens = tokens.view(x.shape[0], 1, -1) + else: + tokens = tokens.squeeze(0) + + if not separate_channels: + return tokens + + @torch.inference_mode() + def get_dense_features(self, x, separate_channels=False): + x = x.to(self.device) + + assert separate_channels == False, "Not supported yet" # TODO: Fix this + + if not separate_channels: + x = x.mean(1) # [B, T] + + if self.should_normalize: + x = torch.cat([nn.functional.layer_norm(item, item.shape) for item in x]) + + feat = [] + for start in range(0, x.size(1), 
self.max_chunk): + x_chunk = x[:, start : start + self.max_chunk] + if x_chunk.size(1) < self.min_chunk: + continue + feat_chunk, _ = self.hubert_model.extract_features( + source=x_chunk, + padding_mask=None, + mask=False, + output_layer=self.hubert_layer, + ) + feat.append(feat_chunk) + + return torch.cat(feat, 1) diff --git a/spiritlm/speech_tokenizer/hubert/quantizer_model.py b/spiritlm/speech_tokenizer/hubert/quantizer_model.py new file mode 100644 index 0000000000000000000000000000000000000000..6b38c1ba1d7987c80ea3033db190a1814614c5ee --- /dev/null +++ b/spiritlm/speech_tokenizer/hubert/quantizer_model.py @@ -0,0 +1,94 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the FAIR Noncommercial Research License +# found in the LICENSE file in the root directory of this source tree. + +import logging + +import torch +from torch import nn + +_logger = logging.getLogger(__name__) + + +class LinearQuantizerModel(nn.Module): + def __init__(self, ckpt_path): + super().__init__() + state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True) + + self.vocab_size = state_dict["model_cfg"]["vocab_size"] + dim = state_dict["model_cfg"]["dim"] + upstream_dim = state_dict["model_cfg"]["upstream_dim"] + + out_dim = self.vocab_size + 1 # vocab_size + 1 for blank in CTC + mid_dim = upstream_dim - out_dim + + self.encoder = nn.Sequential( + *[ + nn.Linear(dim, dim - mid_dim // 4), + nn.LeakyReLU(), + nn.Linear(dim - mid_dim // 4, dim - mid_dim // 2), + nn.LeakyReLU(), + nn.Linear(dim - mid_dim // 2, self.vocab_size + 1), + ] + ) + + self.encoder.load_state_dict(state_dict["model_weight"]) + + def forward(self, x): + logits = self.encoder(x) + logits = torch.nn.functional.log_softmax(logits, dim=-1) + code = logits.argmax(dim=-1) + + # post-process units: replace BLANK with most-left non-BLANK units + non_stop_counter = 0 + while (code == self.vocab_size).any(): + non_stop_counter += 1 + code[code == self.vocab_size] = torch.roll(code, 1)[code == self.vocab_size] + if non_stop_counter == 10000: + break + + return code + + +class KmeansModel(nn.Module): + def __init__(self, km_path): + super().__init__() + states = torch.load(km_path, map_location="cpu", weights_only=True) + assert ( + "cluster_centers" in states and "n_clusters" in states + ), "Not a valid kmeans checkpoint." 
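`LinearQuantizerModel.forward` above decodes the CTC-style output by repeatedly back-filling BLANK positions with the unit immediately to their left, using `torch.roll`. A tiny standalone illustration of that back-fill (the blank id of 500 is just a placeholder for `vocab_size`):

```python
import torch

BLANK = 500  # stands in for self.vocab_size, i.e. the extra CTC blank class
code = torch.tensor([7, BLANK, BLANK, 12, BLANK, 3])
while (code == BLANK).any():
    # torch.roll(code, 1)[i] == code[i - 1], so each BLANK copies its left neighbour;
    # a run of BLANKs shrinks by one element per iteration
    code[code == BLANK] = torch.roll(code, 1)[code == BLANK]
print(code)  # tensor([ 7,  7,  7, 12, 12,  3])
```

The `non_stop_counter` cap in the real code guards against the degenerate case where every frame is BLANK, in which the loop would never converge.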
+ C_np = states["cluster_centers"].transpose() # [d_feats, K] + Cnorm_np = (C_np**2).sum(0, keepdims=True) # [K,] + self.K = states["n_clusters"] + assert self.K == C_np.shape[-1] + + self.C = nn.Parameter(torch.from_numpy(C_np), requires_grad=False) + self.Cnorm = nn.Parameter(torch.from_numpy(Cnorm_np), requires_grad=False) + + def forward(self, x): + batched = False + if len(x.shape) == 3: # [B, T, d] + batched = True + B, T, d = x.shape + x = x.view(-1, d) + + # x: [T, d]; C: [d, K]; Cnorm: [K,] + dist = x.pow(2).sum(1, keepdim=True) - 2 * torch.matmul(x, self.C) + self.Cnorm + assigned_clusters = dist.argmin(dim=1) # [T,] + + if batched: + assigned_clusters = assigned_clusters.view(B, T) + + return assigned_clusters + + +def load_quantizer_model(ckpt_path, is_linear_quantizer): + if is_linear_quantizer: + model = LinearQuantizerModel(ckpt_path) + _logger.info(f"Loaded LinearQuantizer from '{ckpt_path}'") + else: + model = KmeansModel(ckpt_path) + _logger.info(f"Loaded KmeansModel from '{ckpt_path}'") + return model diff --git a/spiritlm/speech_tokenizer/spiritlm_tokenizer.py b/spiritlm/speech_tokenizer/spiritlm_tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..2f25f32c7f8089eda76d2a12f3cabeedc7bc4b13 --- /dev/null +++ b/spiritlm/speech_tokenizer/spiritlm_tokenizer.py @@ -0,0 +1,361 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the FAIR Noncommercial Research License +# found in the LICENSE file in the root directory of this source tree. + +import logging +import os +import random +from typing import Dict, List + +import torchaudio + +MOST_COMMON_STYLES = [71, 68, 98] + + +_logger = logging.getLogger(__name__) + + +def _toks_positions(toks: List[str], rate: float, dedup: bool): + prev_tok = None + res = [] + for i, tok in enumerate(toks): + if (not dedup) or (prev_tok is None or tok != prev_tok): + res += [(tok, i / rate)] + prev_tok = tok + return res + + +def units_to_string( + units: Dict[str, str], + has_pitch=False, + has_style=False, + hubert_rate=24.99, + hubert_dedup=True, + hubert_key="hubert", + pitch_rate=12.5, + pitch_dedup=True, + pitch_key="pitch", + style_rate=1, + style_dedup=False, + style_key="style", +) -> str: + """ + Example: + - input (units): + { + 'hubert': '78 42 81 159 316 259', + 'pitch': '13 13 13 13 13 3', + 'style': '81 81 81 81 81 81', + } + - output: + '[St81][Hu78][Pi13][Hu42][Hu81][Hu159][Hu316][Pi3][Hu259]' + """ + + combine_toks = [] + + if has_style: + combine_toks += _toks_positions( + [f"[St{i}]" for i in units[style_key].split()], style_rate, style_dedup + ) + if has_pitch: + combine_toks += _toks_positions( + [f"[Pi{i}]" for i in units[pitch_key].split()], pitch_rate, pitch_dedup + ) + combine_toks += _toks_positions( + [f"[Hu{i}]" for i in units[hubert_key].split()], hubert_rate, hubert_dedup + ) + combine_toks = [tok_pos[0] for tok_pos in sorted(combine_toks, key=lambda x: x[1])] + return "".join(combine_toks) + + +def get_random_most_common_style() -> int: + return random.choice(MOST_COMMON_STYLES) + + +def string_to_units( + gen, + hubert_key="hubert", + pitch_key="pitch", + style_key="style", + duplicate_hubert_for_multiple_pitch=False, +): + """ + Convert from tokenized string to dictionary of units. + The units are 'pre-duplicated' to match the number of hubert units. 
+ Examples + - input: + '[St81][Hu78][Pi13][Hu42][Hu81][Hu159][Hu316][Pi3][Hu259]' + - output: + { + 'hubert': '78 42 81 159 316 259', + 'pitch': '13 13 13 13 13 3', + 'style': '81 81 81 81 81 81', + } + """ + prev_hubert = None + first_hubert = None + prev_pitch = None + first_pitch = None + prev_style = None + first_style = None + prev_is_pitch = False # If this is True, add prev_hubert to the codes + hubert = [] + pitch = [] + style = [] + for item in gen.split("["): + if item and len(item) > 2: + if item.startswith("Hu") and item[2].isdigit(): + hubert += [item[2:-1]] + pitch += [prev_pitch] + style += [prev_style] + prev_is_pitch = False + prev_hubert = item[2:-1] + if first_hubert is None: + first_hubert = item[2:-1] + elif item.startswith("St") and item[2].isdigit(): + if prev_style is None: + first_style = item[2:-1] + prev_style = item[2:-1] + elif item.startswith("Pi") and item[2].isdigit(): + if duplicate_hubert_for_multiple_pitch and prev_is_pitch: + hubert += [prev_hubert] + pitch += [item[2:-1]] + style += [prev_style] + if prev_pitch is None: + first_pitch = item[2:-1] + prev_pitch = item[2:-1] + prev_is_pitch = True + if first_pitch is not None and first_style is None: + # in rare case, style is not present, we select randomly a common style token to make decoding work + first_style = str(get_random_most_common_style()) + for i in range(len(hubert)): + if hubert[i] is None: + hubert[i] = first_hubert + if style[i] is None: + style[i] = first_style + if pitch[i] is None: + pitch[i] = first_pitch + units = {hubert_key: " ".join(hubert)} + if first_pitch is not None: + units[pitch_key] = " ".join(pitch) + if first_style is not None: + units[style_key] = " ".join(style) + return units + + +class SpiritLMTokenizer: + def __init__( + self, + hubert_model, + pitch_model=None, + style_model=None, + hifigan_model=None, + hubert_rate=24.99, + hubert_dedup=True, + hubert_key="hubert", + pitch_rate=12.5, + pitch_dedup=True, + pitch_key="pitch", + style_rate=1, + style_dedup=False, + style_key="style", + expected_sample_rate=16_000, + max_wav_chunk=100 * 16_000, + min_wav_chunk=1280, # 400 is minimum for hubert, 1280 (80ms) is minimum for pitch, so let's take 1280 + ): + super().__init__() + + self.hubert_model = hubert_model + self.pitch_model = pitch_model + self.style_model = style_model + self.hifigan_model = hifigan_model + + self.hubert_rate = hubert_rate + self.hubert_dedup = hubert_dedup + self.hubert_key = hubert_key + + self.speech_token = "[Speech]" + self.pitch_key = None + self.style_key = None + if pitch_model is not None: + self.pitch_rate = pitch_rate + self.pitch_dedup = pitch_dedup + self.pitch_key = pitch_key + if style_model is not None: + self.style_rate = style_rate + self.style_dedup = style_dedup + self.style_key = style_key + + self.expected_sample_rate = expected_sample_rate + self.max_wav_chunk = max_wav_chunk + self.min_wav_chunk = min_wav_chunk + + def load_audio(self, path): + wav, sr = torchaudio.load(path) + if sr != self.expected_sample_rate: + wav = torchaudio.functional.resample( + wav, orig_freq=sr, new_freq=self.expected_sample_rate + ) + return wav + + def encode_units(self, audio, channel_id=None): + """ + Get the speech units in dictionary format, e.g. + { + 'audio': 'path/to/audio.wav', + 'hubert': '1 1 2 2 3', + 'pitch': '15 15 20', + 'style': '7', + } + The audio can be the path to audio file or an array. + For stereo audio file, channel_id can be set (0 or 1). 
+ """ + units = {} + + if isinstance(audio, str): + units["audio"] = os.path.abspath(audio) + audio = self.load_audio(audio) + audio = audio.squeeze() + if len(audio.shape) == 2: + assert ( + audio.shape[0] == 2 + ), f"expected a stereo wav of shape (2,x), found {audio.shape}" + if channel_id is None: + _logger.warning( + "Found stereo audio input, averaging audio from 2 channels. If you want to extract" + "only one channel, set channel_id to 0 or 1" + ) + audio = audio.mean(0) + else: + audio = audio[channel_id] + assert len(audio.shape) == 1, audio.shape + + hubert_units = [] + pitch_units = [] + style_units = [] + for start in range(0, len(audio), self.max_wav_chunk): + audio_chunk = audio[start : start + self.max_wav_chunk] + if len(audio_chunk) < self.min_wav_chunk: + continue + hubert_units.extend([str(i.item()) for i in self.hubert_model(audio_chunk)]) + if self.pitch_model is not None: + pitch_units.extend( + [str(i.item()) for i in self.pitch_model(audio_chunk)] + ) + if self.style_model is not None: + style_units.extend( + [str(i.item()) for i in self.style_model(audio_chunk)] + ) + + units[self.hubert_key] = " ".join(hubert_units) + if self.pitch_model is not None: + units[self.pitch_key] = " ".join(pitch_units) + if self.style_model is not None: + units[self.style_key] = " ".join(style_units) + return units + + def units2string(self, units): + """ + Convert from dictionary of units to tokenized string. + The units are (optionally deduped) sorted by time steps and interleaved + """ + has_pitch = self.pitch_model is not None + has_style = self.style_model is not None + return units_to_string( + units=units, + has_pitch=has_pitch, + has_style=has_style, + hubert_rate=self.hubert_rate, + hubert_dedup=self.hubert_dedup, + hubert_key=self.hubert_key, + pitch_rate=self.pitch_rate if has_pitch else None, + pitch_dedup=self.pitch_dedup if has_pitch else None, + pitch_key=self.pitch_key if has_pitch else None, + style_rate=self.style_rate if has_style else None, + style_dedup=self.style_dedup if has_style else None, + style_key=self.style_key if has_style else None, + ) + + def encode_string(self, audio): + """ + Tokenize the audio into string format, e.g. + '[St7][Pi15][Hu1][Hu2][Pi20][Hu3]' + """ + units = self.encode_units(audio) + return self.units2string(units) + + def __call__(self, audio): + """ + Default call method + """ + return self.encode_string(audio) + + def string2units(self, gen, duplicate_hubert_for_multiple_pitch=False): + """ + Convert from tokenized string to dictionary of units. + The units are 'pre-duplicated' to match the number of hubert units. 
+ Examples + - input: + '[St81][Hu78][Pi13][Hu42][Hu81][Hu159][Hu316][Pi3][Hu259]' + - output: + { + 'hubert': '78 42 81 159 316 259', + 'pitch': '13 13 13 13 13 3', + 'style': '81 81 81 81 81 81', + } + """ + return string_to_units( + gen, + hubert_key=self.hubert_key, + pitch_key=self.pitch_key if self.pitch_key else "pitch", + style_key=self.style_key if self.style_key else "style", + duplicate_hubert_for_multiple_pitch=duplicate_hubert_for_multiple_pitch, + ) + + def decode(self, code, speaker_id=2, dur_pred=True): + """ + code can be under text form ([Hu1][Hu2]) or units form ({'hubert': '1 2'}) + """ + + assert self.hifigan_model is not None + + if isinstance(code, str): + units = self.string2units(code) + else: + units = code + + # if units['hubert'] doesn't have the same number as units['f0'] + # then likely this is resynthesis task, and we'll set dur_pred=False + if ( + self.pitch_key + and self.pitch_key in units + and len(units[self.pitch_key].split()) + != len(units[self.hubert_key].split()) + ): + dur_pred = False + + wav = ( + self.hifigan_model( + code=units[self.hubert_key], + f0_code=( + units[self.pitch_key] + if self.pitch_key and self.pitch_key in units + else None + ), + style_code=( + units[self.style_key] + if self.style_key and self.style_key in units + else None + ), + dur_pred=dur_pred, + speaker_id=speaker_id, + not_dedup_code=True, + ) + .detach() + .cpu() + .numpy() + ) + + return wav diff --git a/spiritlm/speech_tokenizer/style_encoder/__init__.py b/spiritlm/speech_tokenizer/style_encoder/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..204aaf9b4e50b2285e37954d7be326dc007938da --- /dev/null +++ b/spiritlm/speech_tokenizer/style_encoder/__init__.py @@ -0,0 +1,37 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the FAIR Noncommercial Research License +# found in the LICENSE file in the root directory of this source tree. 
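`string_to_units` / `string2units` above reverse the interleaving: the last seen pitch and style values are carried forward ("pre-duplicated") so that every hubert unit ends up with an aligned pitch and style entry. A short usage sketch reusing the docstring example (assumes the `spiritlm` package is importable):

```python
from spiritlm.speech_tokenizer.spiritlm_tokenizer import string_to_units

encoded = "[St81][Hu78][Pi13][Hu42][Hu81][Hu159][Hu316][Pi3][Hu259]"
units = string_to_units(encoded)
print(units["hubert"])  # '78 42 81 159 316 259'  (one entry per [Hu..] token)
print(units["pitch"])   # '13 13 13 13 13 3'      (last seen pitch, repeated per hubert unit)
print(units["style"])   # '81 81 81 81 81 81'     (single style token propagated to every unit)
```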
+ + +import logging +import os +from pathlib import Path + +import torch + +from .w2v2_encoder import Wav2Vec2StyleEncoder + +_logger = logging.getLogger(__name__) + +# Get the base checkpoints directory from environment variable or use the default base path +base_checkpoints_dir = Path(os.getenv("SPIRITLM_CHECKPOINTS_DIR", Path(__file__).parents[3] / "checkpoints")) + +# Append 'speech_tokenizer' to the base path +CHECKPOINT_DIR = base_checkpoints_dir / "speech_tokenizer" + +CURRENT_DEVICE = ( + torch.device(torch.cuda.current_device()) + if torch.cuda.is_available() + else "mps" if torch.backends.mps.is_available() else "cpu" +) + + +def spiritlm_expressive_style_encoder_w2v2() -> Wav2Vec2StyleEncoder: + STYLE_ENCODER_CKPT_PATH = CHECKPOINT_DIR / "style_encoder_w2v2" + model = Wav2Vec2StyleEncoder.from_pretrained( + pretrained_model_name_or_path=STYLE_ENCODER_CKPT_PATH + ).to(CURRENT_DEVICE) + _logger.info(f"Style encoder loaded from {str(STYLE_ENCODER_CKPT_PATH)}") + return model diff --git a/spiritlm/speech_tokenizer/style_encoder/__pycache__/__init__.cpython-310.pyc b/spiritlm/speech_tokenizer/style_encoder/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b5d9760efc28229f38fbd9811d1b1cb1f6d195c5 Binary files /dev/null and b/spiritlm/speech_tokenizer/style_encoder/__pycache__/__init__.cpython-310.pyc differ diff --git a/spiritlm/speech_tokenizer/style_encoder/__pycache__/w2v2_encoder.cpython-310.pyc b/spiritlm/speech_tokenizer/style_encoder/__pycache__/w2v2_encoder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9cde9e447d5c285c2e39eea0170c501d6d7be33b Binary files /dev/null and b/spiritlm/speech_tokenizer/style_encoder/__pycache__/w2v2_encoder.cpython-310.pyc differ diff --git a/spiritlm/speech_tokenizer/style_encoder/w2v2_encoder.py b/spiritlm/speech_tokenizer/style_encoder/w2v2_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..c02b483e56b5fc0db670ebdf785317b771fef303 --- /dev/null +++ b/spiritlm/speech_tokenizer/style_encoder/w2v2_encoder.py @@ -0,0 +1,56 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the FAIR Noncommercial Research License +# found in the LICENSE file in the root directory of this source tree. 
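`CHECKPOINT_DIR` above is resolved once at import time, so `SPIRITLM_CHECKPOINTS_DIR` has to be set before the module is imported. A usage sketch (the checkpoint path is a hypothetical example and assumes the released checkpoints have been downloaded):

```python
import os

# must happen before importing the package, since the path is read at import time
os.environ["SPIRITLM_CHECKPOINTS_DIR"] = "/data/spiritlm/checkpoints"  # hypothetical location

from spiritlm.speech_tokenizer.style_encoder import spiritlm_expressive_style_encoder_w2v2

style_encoder = spiritlm_expressive_style_encoder_w2v2()  # loads <checkpoints>/speech_tokenizer/style_encoder_w2v2
style_units = style_encoder("examples/audio/7143-88743-0029.flac")
print(style_units)  # roughly one style id per second of audio (pool_size=50 wav2vec2 frames)
```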
+ +from typing import Union + +import torch +import torchaudio +from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2ForSequenceClassification + + +class Wav2Vec2StyleEncoder(Wav2Vec2ForSequenceClassification): + def __init__(self, config, pool_size: int = 50): + super().__init__(config) + self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained( + "facebook/wav2vec2-base" + ) + self.pool_size = pool_size + + # this is useful for determining the device + self.register_buffer("_float_tensor", torch.tensor([0], dtype=torch.float)) + + @property + def device(self): + return self._float_tensor.device + + @torch.no_grad() + def forward(self, wavs: Union[torch.Tensor, str]) -> torch.Tensor: + if isinstance(wavs, str): + # TODO: resampling if applicable + wavs = torchaudio.load(wavs)[0].squeeze(0) + # TODO: handle list of strs + inputs = self.feature_extractor( + wavs, sampling_rate=16_000, return_tensors="pt" + ).input_values + outputs = self.wav2vec2(inputs.to(self.device)) + hidden_states = outputs[0] + hidden_states = self.projector(hidden_states) + chunk_size = self.pool_size + batch_size, sequence_length, hidden_size = hidden_states.shape + pooled_output = [] + for i in range(0, sequence_length, chunk_size): + chunk = hidden_states[:, i : i + chunk_size, :] + pooled_output.append(chunk.mean(dim=1)) + pooled_output = torch.cat( + pooled_output, dim=1 + ) # Concatenate the chunks along the desired dimension + pooled_output = pooled_output.view( + batch_size, -1, hidden_size + ) # Reshape back to the original shape + logits = self.classifier(pooled_output) + lprobs = torch.nn.functional.log_softmax(logits, dim=-1) + pred = torch.argmax(lprobs, dim=-1).squeeze(0) + return pred diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..17edf7fe569a6e89feb6107141c62c96ab684ea5 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the FAIR Noncommercial Research License +# found in the LICENSE file in the root directory of this source tree. diff --git a/tests/test_spirit_model.py b/tests/test_spirit_model.py new file mode 100644 index 0000000000000000000000000000000000000000..e245b8b3805dd395a8d28e0a2725548ee581409b --- /dev/null +++ b/tests/test_spirit_model.py @@ -0,0 +1,184 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the FAIR Noncommercial Research License +# found in the LICENSE file in the root directory of this source tree. 
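The pooling step in `Wav2Vec2StyleEncoder.forward` above averages hidden states over non-overlapping windows of `pool_size` frames before classification, turning a per-frame sequence into a much shorter per-chunk one. A standalone sketch of that operation (function name is illustrative):

```python
import torch

def chunk_mean_pool(hidden_states: torch.Tensor, pool_size: int = 50) -> torch.Tensor:
    # (B, T, H) -> (B, ceil(T / pool_size), H): mean over fixed-size, non-overlapping windows
    batch_size, seq_len, hidden = hidden_states.shape
    pooled = [
        hidden_states[:, i : i + pool_size, :].mean(dim=1)
        for i in range(0, seq_len, pool_size)
    ]
    return torch.stack(pooled, dim=1)

x = torch.randn(1, 123, 256)
print(chunk_mean_pool(x).shape)  # torch.Size([1, 3, 256]); the final, shorter chunk is averaged on its own
```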
+ +from unittest.mock import Mock, patch + +import pytest +from spiritlm.model.spiritlm_model import Spiritlm +from spiritlm.model.utils import ( + does_end_with_speech_token, + does_start_with_speech_token, + find_prompt_last_speech_start_position, +) + + +@pytest.mark.parametrize( + "content,expected", + [ + ( + "abc[Speech][St1][Pi234][Hu123][Hu45][Text]hello world[", + [("abc", "t"), ("[St1][Pi234][Hu123][Hu45]", "s"), ("hello world[", "t")], + ), + ( + "[St1][Pi234][Hu123][Hu45]", + [("[St1][Pi234][Hu123][Hu45]", "s")], + ), + ( + "abc", + [("abc", "t")], + ), + ( + "abc[]", + [("abc[]", "t")], + ), + ( + "[St1][Pi234][Hu123][Hu45][Text][abc", + [("[St1][Pi234][Hu123][Hu45]", "s"), ("[abc", "t")], + ), + ( + "abc[Text]def", + [("abcdef", "t")], + ), + ], +) +def test_parse_speech_and_text(content, expected): + with patch( + "spiritlm.model.spiritlm_model.Spiritlm.__init__", Mock(return_value=None) + ): + mock_spiritlm_model = Spiritlm("spirit-lm-base-7b") + mock_spiritlm_model.speech_prompt_prefix = "[Speech]" + assert mock_spiritlm_model._parse_speech_and_text(content) == expected + + +@pytest.mark.parametrize( + "content,expected", + [ + ( + "[Hu338][Text] and they went out together[Speech][Hu431][Pi0][Hu457][Hu79][Pi11][Hu258][Hu85][Hu28][Hu50][Text] and mrs johnson shoes except in mourning[Speech][Pi59][Hu32][Pi20][Hu453][Pi35][Pi26][Hu166]", + [ + ("[Hu338]", "s"), + (" and they went out together", "t"), + ("[Hu431][Pi0][Hu457][Hu79][Pi11][Hu258][Hu85][Hu28][Hu50]", "s"), + (" and mrs johnson shoes except in mourning", "t"), + ("[Pi59][Hu32][Pi20][Hu453][Pi35][Pi26][Hu166]", "s"), + ], + ) + ], +) +def test_parse_speech_and_text_with_expressive_tokens(content, expected): + with patch( + "spiritlm.model.spiritlm_model.Spiritlm.__init__", Mock(return_value=None) + ): + mock_spiritlm_model = Spiritlm("spirit-lm-base-7b") + mock_spiritlm_model.speech_prompt_prefix = "[Speech]" + print(f"content: {content}") + print(f"expected: {expected}") + assert mock_spiritlm_model._parse_speech_and_text(content) == expected + + +@pytest.mark.parametrize( + "encoded_string,expected", + [ + ( + "]]", + False, + ), + ( + "[]", + False, + ), + ( + "[Hu100]", + True, + ), + ("abc[]", False), + ( + "[St1][Pi234][Hu123][Hu45][Text][abc]", + False, + ), + ( + "abc[Text]def", + False, + ), + ( + "[Pi9]", + True, + ), + ( + "[St0]", + True, + ), + ], +) +def test_does_prompt_end_by_speech(encoded_string, expected): + assert does_end_with_speech_token(encoded_string) == expected + + +@pytest.mark.parametrize( + "encoded_string,expected", + [ + ( + "abc[Hu123][Hu456][Pi23][St2]", + 3, + ), + ( + "[Hu123]abc[Hu123][Hu456][Pi23][St2]", + 10, + ), + ( + "[Hu123][Hu456][Pi23][St2]", + 0, + ), + ( + "abc", + None, + ), + ( + 
"[Speech][St71][Pi39][Hu99][Hu49][Pi57][Hu38][Hu149][Pi48][Hu71][Hu423][Hu427][Pi56][Hu492][Hu288][Pi44][Hu315][Hu153][Pi42][Hu389][Pi59][Hu497][Hu412][Pi51][Hu247][Hu354][Pi44][Hu7][Hu96][Pi43][Hu452][Pi0][Hu176][Hu266][Pi54][St71][Hu77][Pi13][Hu248][Hu336][Pi39][Hu211][Pi25][Hu166][Hu65][Pi58][Hu94][Hu224][Pi26][Hu148][Pi44][Hu492][Hu191][Pi26][Hu440][Pi13][Hu41][Pi20][Hu457][Hu79][Pi46][Hu382][Hu451][Pi26][Hu332][Hu216][Hu114][Hu340][St71][Pi40][Hu478][Hu74][Pi26][Hu79][Hu370][Pi56][Hu272][Hu370][Pi51][Hu53][Pi14][Hu477][Hu65][Pi46][Hu171][Hu60][Pi41][Hu258][Hu111][Pi40][Hu338][Hu23][Pi39][Hu338][Hu23][Hu338][St71][Pi57][Hu7][Hu338][Hu149][Pi59][Hu406][Hu7][Hu361][Hu99][Pi20][Hu209][Hu479][Pi35][Hu50][St71][Hu7][Hu149][Pi55][Hu35][Pi13][Hu130][Pi3][Hu169][Pi52][Hu72][Pi9][Hu434][Hu119][Hu272][Hu4][Pi20][Hu249][Hu245][Pi57][Hu433][Pi56][Hu159][Hu294][Hu139][Hu359][Hu343][Hu269][Hu302][St71][Hu226][Pi32][Hu370][Hu216][Pi39][Hu459][Hu424][Pi57][Hu226][Pi46][Hu382][Hu7][Pi27][Hu58][Hu138][Pi20][Hu428][Hu397][Pi44][Hu350][Pi32][Hu306][Pi59][Hu84][Hu11][Hu171][Pi42][Hu60][Pi48][Hu314][Hu227][St71][Hu355][Pi56][Hu9][Hu58][Pi44][Hu138][Hu226][Pi25][Hu370][Hu272][Pi56][Hu382][Hu334][Pi26][Hu330][Hu176][Pi56][Hu307][Pi46][Hu145][Hu248][Pi56][Hu493][Hu64][Pi40][Hu44][Hu388][Pi39][Hu7][Hu111][Pi59][St71][Hu23][Hu481][Pi13][Hu149][Pi15][Hu80][Hu70][Pi47][Hu431][Hu457][Pi13][Hu79][Pi27][Hu249][Pi55][Hu245][Pi54][Hu433][Pi36][Hu316][Pi53][Hu180][Pi3][Hu458][Pi26][Hu86][St71][Pi43][Hu225][Pi49][Hu103][Hu60][Pi3][Hu96][Hu119][Pi39][Hu129][Pi41][Hu356][Hu218][Pi14][Hu4][Hu259][Pi56][Hu392][Pi46][Hu490][Hu75][Pi14][Hu488][Hu166][Pi46][Hu65][Hu171][Pi40][Hu60][Hu7][Hu54][Pi39][Hu85][St83][Pi40][Hu361]", + 8, + ), + ], +) +def test_find_prompt_last_speech_start_position(encoded_string, expected): + assert find_prompt_last_speech_start_position(encoded_string) == expected + + +@pytest.mark.parametrize( + "encoded_string,expected", + [ + ( + "[[", + False, + ), + ( + "[]", + False, + ), + ( + "[Hu100]", + True, + ), + ("abc[]", False), + ( + "[St1][Pi234][Hu123][Hu45][Text][abc]", + True, + ), + ( + "abc[Text]def", + False, + ), + ( + "[Pi9]", + True, + ), + ( + "[St0]", + True, + ), + ], +) +def test_does_start_with_speech_token(encoded_string, expected): + assert does_start_with_speech_token(encoded_string) == expected diff --git a/tests/test_tokenizer.py b/tests/test_tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..75ece3393d15b8483546e7d6cc4b93bb5c545c83 --- /dev/null +++ b/tests/test_tokenizer.py @@ -0,0 +1,70 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the FAIR Noncommercial Research License +# found in the LICENSE file in the root directory of this source tree. 
+ +import pytest +import torchaudio +from spiritlm.speech_tokenizer import spiritlm_base, spiritlm_expressive + + +@pytest.fixture +def spiritlm_expressive_tokenizer(): + return spiritlm_expressive() + + +@pytest.fixture +def spiritlm_base_tokenizer(): + return spiritlm_base() + + +def test_expressive_tokenizer_encode_units(spiritlm_expressive_tokenizer): + audio = "examples/audio/7143-88743-0029.flac" + units = spiritlm_expressive_tokenizer.encode_units(audio) + expected = { + "hubert": "99 49 38 149 149 71 423 427 492 288 315 153 153 389 497 412 247 354 7 96 452 452 176 266 266 77 248 336 336 211 166 65 94 224 224 148 492 191 440 440 41 41 457 79 382 451 332 216 114 340 478 74 79 370 272 370 370 53 477 65 171 60 258 111 111 111 111 338 338 23 23 338 23 338 338 338 7 338 338 149 406 7 361 361 361 99 99 99 99 99 99 99 209 209 209 209 209 479 50 50 7 149 149 35 35 130 130 169 169 72 434 119 272 4 249 245 245 433 159 294 139 359 343 269 302 226 370 216 459 424 424 226 382 7 58 138 428 397 350 350 306 306 306 84 11 171 171 60 314 227 227 355 9 58 138 226 370 272 382 334 330 176 176 307 145 248 493 64 44 388 7 111 111 111 111 23 23 481 149 149 80 70 431 457 79 79 249 249 245 245 245 433 433 316 316 180 458 458 458 86 86 225 103 60 96 119 119 129 356 218 4 259 259 392 490 75 488 166 65 171 60 7 54 54 85 85 361 361", + "pitch": "39 39 39 48 56 40 42 39 51 40 43 54 3 35 39 25 58 26 44 40 13 20 46 41 26 40 26 56 41 46 46 41 41 40 40 40 39 39 57 59 59 59 59 59 59 59 59 20 20 20 35 35 13 3 9 6 0 20 57 56 56 56 56 59 44 57 41 59 42 51 59 57 59 59 39 39 46 56 58 41 41 40 39 39 39 59 59 59 15 27 13 55 13 27 35 36 3 53 3 26 43 53 54 39 25 14 41 46 46 46 46 41 41 41", + "style": "71 71 71 71 71 71 71 71 71 83", + } + for token_key in ["hubert", "pitch", "style"]: + assert ( + expected[token_key] == units[token_key] + ), f"{token_key} expected {expected[token_key]}, got {units[token_key]}" + + +def test_expressive_tokenizer_encode_units_with_tensor_input( + spiritlm_expressive_tokenizer, +): + wav = torchaudio.load("examples/audio/7143-88743-0029.flac")[0].squeeze(0) + units = spiritlm_expressive_tokenizer.encode_units(wav) + expected = { + "hubert": "99 49 38 149 149 71 423 427 492 288 315 153 153 389 497 412 247 354 7 96 452 452 176 266 266 77 248 336 336 211 166 65 94 224 224 148 492 191 440 440 41 41 457 79 382 451 332 216 114 340 478 74 79 370 272 370 370 53 477 65 171 60 258 111 111 111 111 338 338 23 23 338 23 338 338 338 7 338 338 149 406 7 361 361 361 99 99 99 99 99 99 99 209 209 209 209 209 479 50 50 7 149 149 35 35 130 130 169 169 72 434 119 272 4 249 245 245 433 159 294 139 359 343 269 302 226 370 216 459 424 424 226 382 7 58 138 428 397 350 350 306 306 306 84 11 171 171 60 314 227 227 355 9 58 138 226 370 272 382 334 330 176 176 307 145 248 493 64 44 388 7 111 111 111 111 23 23 481 149 149 80 70 431 457 79 79 249 249 245 245 245 433 433 316 316 180 458 458 458 86 86 225 103 60 96 119 119 129 356 218 4 259 259 392 490 75 488 166 65 171 60 7 54 54 85 85 361 361", + "pitch": "39 39 39 48 56 40 42 39 51 40 43 54 3 35 39 25 58 26 44 40 13 20 46 41 26 40 26 56 41 46 46 41 41 40 40 40 39 39 57 59 59 59 59 59 59 59 59 20 20 20 35 35 13 3 9 6 0 20 57 56 56 56 56 59 44 57 41 59 42 51 59 57 59 59 39 39 46 56 58 41 41 40 39 39 39 59 59 59 15 27 13 55 13 27 35 36 3 53 3 26 43 53 54 39 25 14 41 46 46 46 46 41 41 41", + "style": "71 71 71 71 71 71 71 71 71 83", + } + for token_key in ["hubert", "pitch", "style"]: + assert ( + expected[token_key] == units[token_key] + ), f"{token_key} expected 
{expected[token_key]}, got {units[token_key]}" + + +def test_base_tokenizer_encode_units(spiritlm_base_tokenizer): + audio = "examples/audio/7143-88743-0029.flac" + units = spiritlm_base_tokenizer.encode_units(audio) + expected_hubert = "99 49 38 149 149 71 423 427 492 288 315 153 153 389 497 412 247 354 7 96 452 452 176 266 266 77 248 336 336 211 166 65 94 224 224 148 492 191 440 440 41 41 457 79 382 451 332 216 114 340 478 74 79 370 272 370 370 53 477 65 171 60 258 111 111 111 111 338 338 23 23 338 23 338 338 338 7 338 338 149 406 7 361 361 361 99 99 99 99 99 99 99 209 209 209 209 209 479 50 50 7 149 149 35 35 130 130 169 169 72 434 119 272 4 249 245 245 433 159 294 139 359 343 269 302 226 370 216 459 424 424 226 382 7 58 138 428 397 350 350 306 306 306 84 11 171 171 60 314 227 227 355 9 58 138 226 370 272 382 334 330 176 176 307 145 248 493 64 44 388 7 111 111 111 111 23 23 481 149 149 80 70 431 457 79 79 249 249 245 245 245 433 433 316 316 180 458 458 458 86 86 225 103 60 96 119 119 129 356 218 4 259 259 392 490 75 488 166 65 171 60 7 54 54 85 85 361 361" + assert expected_hubert == units["hubert"] + + +def test_expressive_tokenizer_encode_string(spiritlm_expressive_tokenizer): + audio = "examples/audio/7143-88743-0029.flac" + encoded_string = spiritlm_expressive_tokenizer.encode_string(audio) + expected = "[St71][Pi39][Hu99][Hu49][Hu38][Hu149][Hu71][Pi48][Hu423][Hu427][Pi56][Hu492][Hu288][Pi40][Hu315][Hu153][Pi42][Hu389][Pi39][Hu497][Hu412][Pi51][Hu247][Hu354][Pi40][Hu7][Hu96][Pi43][Hu452][Pi54][Hu176][Hu266][Pi3][St71][Hu77][Pi35][Hu248][Hu336][Pi39][Hu211][Pi25][Hu166][Hu65][Pi58][Hu94][Hu224][Pi26][Hu148][Pi44][Hu492][Hu191][Pi40][Hu440][Pi13][Hu41][Pi20][Hu457][Hu79][Pi46][Hu382][Hu451][Pi41][Hu332][Hu216][Pi26][Hu114][Hu340][St71][Pi40][Hu478][Hu74][Pi26][Hu79][Hu370][Pi56][Hu272][Hu370][Pi41][Hu53][Pi46][Hu477][Hu65][Hu171][Hu60][Pi41][Hu258][Hu111][Pi40][Hu338][Hu23][Hu338][Pi39][Hu23][Hu338][St71][Pi57][Hu7][Hu338][Pi59][Hu149][Hu406][Hu7][Hu361][Hu99][Hu209][Pi20][Hu479][Hu50][St71][Pi35][Hu7][Hu149][Hu35][Pi13][Hu130][Pi3][Hu169][Pi9][Hu72][Pi6][Hu434][Hu119][Pi0][Hu272][Hu4][Pi20][Hu249][Hu245][Pi57][Hu433][Pi56][Hu159][Hu294][Hu139][Hu359][Hu343][Hu269][Hu302][St71][Hu226][Pi59][Hu370][Hu216][Pi44][Hu459][Hu424][Pi57][Hu226][Pi41][Hu382][Hu7][Pi59][Hu58][Hu138][Pi42][Hu428][Hu397][Pi51][Hu350][Pi59][Hu306][Pi57][Hu84][Pi59][Hu11][Hu171][Hu60][Pi39][Hu314][Hu227][St71][Hu355][Pi46][Hu9][Hu58][Pi56][Hu138][Hu226][Pi58][Hu370][Hu272][Pi41][Hu382][Hu334][Hu330][Hu176][Pi40][Hu307][Pi39][Hu145][Hu248][Hu493][Hu64][Hu44][Hu388][Pi59][Hu7][Hu111][St71][Hu23][Pi15][Hu481][Pi27][Hu149][Pi13][Hu80][Hu70][Pi55][Hu431][Hu457][Pi13][Hu79][Pi27][Hu249][Pi35][Hu245][Pi36][Hu433][Pi3][Hu316][Pi53][Hu180][Pi3][Hu458][Pi26][Hu86][St71][Pi43][Hu225][Pi53][Hu103][Hu60][Pi54][Hu96][Hu119][Pi39][Hu129][Pi25][Hu356][Hu218][Pi14][Hu4][Hu259][Pi41][Hu392][Pi46][Hu490][Hu75][Hu488][Hu166][Hu65][Hu171][Hu60][Hu7][Pi41][Hu54][Hu85][St83][Hu361]" + assert encoded_string == expected + + +def test_base_tokenizer_encode_string(spiritlm_base_tokenizer): + audio = "examples/audio/7143-88743-0029.flac" + encoded_string = spiritlm_base_tokenizer.encode_string(audio) + expected = 
"[Hu99][Hu49][Hu38][Hu149][Hu71][Hu423][Hu427][Hu492][Hu288][Hu315][Hu153][Hu389][Hu497][Hu412][Hu247][Hu354][Hu7][Hu96][Hu452][Hu176][Hu266][Hu77][Hu248][Hu336][Hu211][Hu166][Hu65][Hu94][Hu224][Hu148][Hu492][Hu191][Hu440][Hu41][Hu457][Hu79][Hu382][Hu451][Hu332][Hu216][Hu114][Hu340][Hu478][Hu74][Hu79][Hu370][Hu272][Hu370][Hu53][Hu477][Hu65][Hu171][Hu60][Hu258][Hu111][Hu338][Hu23][Hu338][Hu23][Hu338][Hu7][Hu338][Hu149][Hu406][Hu7][Hu361][Hu99][Hu209][Hu479][Hu50][Hu7][Hu149][Hu35][Hu130][Hu169][Hu72][Hu434][Hu119][Hu272][Hu4][Hu249][Hu245][Hu433][Hu159][Hu294][Hu139][Hu359][Hu343][Hu269][Hu302][Hu226][Hu370][Hu216][Hu459][Hu424][Hu226][Hu382][Hu7][Hu58][Hu138][Hu428][Hu397][Hu350][Hu306][Hu84][Hu11][Hu171][Hu60][Hu314][Hu227][Hu355][Hu9][Hu58][Hu138][Hu226][Hu370][Hu272][Hu382][Hu334][Hu330][Hu176][Hu307][Hu145][Hu248][Hu493][Hu64][Hu44][Hu388][Hu7][Hu111][Hu23][Hu481][Hu149][Hu80][Hu70][Hu431][Hu457][Hu79][Hu249][Hu245][Hu433][Hu316][Hu180][Hu458][Hu86][Hu225][Hu103][Hu60][Hu96][Hu119][Hu129][Hu356][Hu218][Hu4][Hu259][Hu392][Hu490][Hu75][Hu488][Hu166][Hu65][Hu171][Hu60][Hu7][Hu54][Hu85][Hu361]" + assert encoded_string == expected