mohammadmahdinouri committed on
Commit 43acf85 · Parent: 2955de7

added codes and tokenizers

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. CODE_OF_CONDUCT.md +80 -0
  2. CONTRIBUTING.md +31 -0
  3. LICENSE +116 -0
  4. MODEL_CARD.md +81 -0
  5. README.md +55 -0
  6. assets/spiritlm_overview.png +0 -0
  7. checkpoints/README.md +52 -0
  8. checkpoints/speech_tokenizer/hifigan_spiritlm_base/config.json +55 -0
  9. checkpoints/speech_tokenizer/hifigan_spiritlm_base/generator.pt +3 -0
  10. checkpoints/speech_tokenizer/hifigan_spiritlm_base/speakers.txt +4 -0
  11. checkpoints/speech_tokenizer/hifigan_spiritlm_base/styles.txt +34 -0
  12. checkpoints/speech_tokenizer/hifigan_spiritlm_expressive_w2v2/config.json +60 -0
  13. checkpoints/speech_tokenizer/hifigan_spiritlm_expressive_w2v2/generator.pt +3 -0
  14. checkpoints/speech_tokenizer/hifigan_spiritlm_expressive_w2v2/speakers.txt +4 -0
  15. checkpoints/speech_tokenizer/hubert_25hz/L11_quantizer_500.pt +3 -0
  16. checkpoints/speech_tokenizer/hubert_25hz/mhubert_base_25hz.pt +3 -0
  17. checkpoints/speech_tokenizer/style_encoder_w2v2/config.json +321 -0
  18. checkpoints/speech_tokenizer/style_encoder_w2v2/pytorch_model.bin +3 -0
  19. checkpoints/speech_tokenizer/vqvae_f0_quantizer/config.yaml +59 -0
  20. checkpoints/speech_tokenizer/vqvae_f0_quantizer/model.pt +3 -0
  21. data/examples/pred.jsonl +5 -0
  22. data/examples/ref.jsonl +5 -0
  23. env.yml +19 -0
  24. examples/audio/7143-88743-0029.flac +0 -0
  25. examples/distributed_inference_recipe/multi_nodes.slurm +24 -0
  26. examples/distributed_inference_recipe/run_dist.py +89 -0
  27. examples/speech_generation/spirit_model.ipynb +0 -0
  28. examples/speech_tokenizer/spiritlm_speech_tokenizer.ipynb +0 -0
  29. requirements.dev.txt +1 -0
  30. requirements.txt +9 -0
  31. setup.py +60 -0
  32. spiritlm.egg-info/PKG-INFO +84 -0
  33. spiritlm.egg-info/SOURCES.txt +32 -0
  34. spiritlm.egg-info/dependency_links.txt +1 -0
  35. spiritlm.egg-info/not-zip-safe +1 -0
  36. spiritlm.egg-info/requires.txt +15 -0
  37. spiritlm.egg-info/top_level.txt +2 -0
  38. spiritlm/__init__.py +5 -0
  39. spiritlm/__pycache__/__init__.cpython-310.pyc +0 -0
  40. spiritlm/eval/README.md +92 -0
  41. spiritlm/eval/eval_stsp.py +87 -0
  42. spiritlm/eval/load_data.py +50 -0
  43. spiritlm/eval/stsp/few_shot_prompt.py +101 -0
  44. spiritlm/eval/stsp/predict_stsp.py +299 -0
  45. spiritlm/eval/stsp/sanity_check_download.py +30 -0
  46. spiritlm/eval/stsp/sentiment_classifiers.py +37 -0
  47. spiritlm/eval/stsp/stsp_constants.py +12 -0
  48. spiritlm/eval/stsp/utils.py +122 -0
  49. spiritlm/eval/utils.py +17 -0
  50. spiritlm/model/README.md +82 -0
CODE_OF_CONDUCT.md ADDED
@@ -0,0 +1,80 @@
1
+ # Code of Conduct
2
+
3
+ ## Our Pledge
4
+
5
+ In the interest of fostering an open and welcoming environment, we as
6
+ contributors and maintainers pledge to make participation in our project and
7
+ our community a harassment-free experience for everyone, regardless of age, body
8
+ size, disability, ethnicity, sex characteristics, gender identity and expression,
9
+ level of experience, education, socio-economic status, nationality, personal
10
+ appearance, race, religion, or sexual identity and orientation.
11
+
12
+ ## Our Standards
13
+
14
+ Examples of behavior that contributes to creating a positive environment
15
+ include:
16
+
17
+ * Using welcoming and inclusive language
18
+ * Being respectful of differing viewpoints and experiences
19
+ * Gracefully accepting constructive criticism
20
+ * Focusing on what is best for the community
21
+ * Showing empathy towards other community members
22
+
23
+ Examples of unacceptable behavior by participants include:
24
+
25
+ * The use of sexualized language or imagery and unwelcome sexual attention or
26
+ advances
27
+ * Trolling, insulting/derogatory comments, and personal or political attacks
28
+ * Public or private harassment
29
+ * Publishing others' private information, such as a physical or electronic
30
+ address, without explicit permission
31
+ * Other conduct which could reasonably be considered inappropriate in a
32
+ professional setting
33
+
34
+ ## Our Responsibilities
35
+
36
+ Project maintainers are responsible for clarifying the standards of acceptable
37
+ behavior and are expected to take appropriate and fair corrective action in
38
+ response to any instances of unacceptable behavior.
39
+
40
+ Project maintainers have the right and responsibility to remove, edit, or
41
+ reject comments, commits, code, wiki edits, issues, and other contributions
42
+ that are not aligned to this Code of Conduct, or to ban temporarily or
43
+ permanently any contributor for other behaviors that they deem inappropriate,
44
+ threatening, offensive, or harmful.
45
+
46
+ ## Scope
47
+
48
+ This Code of Conduct applies within all project spaces, and it also applies when
49
+ an individual is representing the project or its community in public spaces.
50
+ Examples of representing a project or community include using an official
51
+ project e-mail address, posting via an official social media account, or acting
52
+ as an appointed representative at an online or offline event. Representation of
53
+ a project may be further defined and clarified by project maintainers.
54
+
55
+ This Code of Conduct also applies outside the project spaces when there is a
56
+ reasonable belief that an individual's behavior may have a negative impact on
57
+ the project or its community.
58
+
59
+ ## Enforcement
60
+
61
+ Instances of abusive, harassing, or otherwise unacceptable behavior may be
62
+ reported by contacting the project team at <[email protected]>. All
63
+ complaints will be reviewed and investigated and will result in a response that
64
+ is deemed necessary and appropriate to the circumstances. The project team is
65
+ obligated to maintain confidentiality with regard to the reporter of an incident.
66
+ Further details of specific enforcement policies may be posted separately.
67
+
68
+ Project maintainers who do not follow or enforce the Code of Conduct in good
69
+ faith may face temporary or permanent repercussions as determined by other
70
+ members of the project's leadership.
71
+
72
+ ## Attribution
73
+
74
+ This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
75
+ available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
76
+
77
+ [homepage]: https://www.contributor-covenant.org
78
+
79
+ For answers to common questions about this code of conduct, see
80
+ https://www.contributor-covenant.org/faq
CONTRIBUTING.md ADDED
@@ -0,0 +1,31 @@
1
+ # Contributing to spiritlm
2
+ We want to make contributing to this project as easy and transparent as
3
+ possible.
4
+
5
+ ## Pull Requests
6
+ We actively welcome your pull requests.
7
+
8
+ 1. Fork the repo and create your branch from `main`.
9
+ 2. If you've added code that should be tested, add tests.
10
+ 3. If you've changed APIs, update the documentation.
11
+ 4. Ensure the test suite passes.
12
+ 5. Make sure your code lints.
13
+ 6. If you haven't already, complete the Contributor License Agreement ("CLA").
14
+
15
+ ## Contributor License Agreement ("CLA")
16
+ In order to accept your pull request, we need you to submit a CLA. You only need
17
+ to do this once to work on any of Facebook's open source projects.
18
+
19
+ Complete your CLA here: <https://code.facebook.com/cla>
20
+
21
+ ## Issues
22
+ We use GitHub issues to track public bugs. Please ensure your description is
23
+ clear and has sufficient instructions to be able to reproduce the issue.
24
+
25
+ Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe
26
+ disclosure of security bugs. In those cases, please go through the process
27
+ outlined on that page and do not file a public issue.
28
+
29
+ ## License
30
+ By contributing to spiritlm, you agree that your contributions will be licensed
31
+ under the LICENSE file in the root directory of this source tree.
LICENSE ADDED
@@ -0,0 +1,116 @@
1
+ FAIR Noncommercial Research License
2
+ Last Updated: October 18, 2024
3
+
4
+ “Acceptable Use Policy” means the FAIR Acceptable Use Policy, applicable to Research Materials, that is incorporated into this Agreement.
5
+
6
+ “Agreement” means the terms and conditions for use, reproduction, distribution and modification of the Research Materials set forth herein.
7
+
8
+ “Documentation” means the specifications, manuals and documentation accompanying
9
+ Research Materials distributed by Meta.
10
+
11
+ “Licensee” or “you” means you, or your employer or any other person or entity (if you are entering into this Agreement on such person or entity’s behalf), of the age required under applicable laws, rules or regulations to provide legal consent and that has legal authority to bind your employer or such other person or entity if you are entering in this Agreement on their behalf.
12
+
13
+ “Meta” or “we” means Meta Platforms Ireland Limited (if you are located in or, if you are an entity, your principal place of business is in the EEA or Switzerland) and Meta Platforms, Inc. (if you are located outside of the EEA or Switzerland).
14
+
15
+ “Noncommercial Research Uses” means noncommercial research use cases related to research, development, education, processing, or analysis and in each case, is not primarily intended for commercial advantage or monetary compensation to you or others.
16
+
17
+ “Research Materials” means, collectively, Documentation and the models, software and algorithms, including machine-learning model code, trained model weights, inference-enabling code, training-enabling code, fine-tuning enabling code, demonstration materials and other elements of the foregoing distributed by Meta and made available under this Agreement.
18
+
19
+ By clicking “I Accept” below or by using or distributing any portion or element of the Research Materials, you agree to be bound by this Agreement.
20
+
21
+
22
+ 1. License Rights and Redistribution.
23
+
24
+ a. Grant of Rights. You are granted a non-exclusive, worldwide, non-transferable and royalty-free limited license under Meta’s intellectual property or other rights owned by Meta embodied in the Research Materials to use, reproduce, distribute, copy, create derivative works of, and make modifications to the Research Materials.
25
+
26
+ b. Redistribution and Use.
27
+ i. You will not use the Research Materials or any outputs or results of the Research Materials in connection with any commercial uses or for any uses other than Noncommercial Research Uses;
28
+
29
+ ii. Distribution of Research Materials, and any derivative works thereof, are subject to the terms of this Agreement. If you distribute or make the Research Materials, or any derivative works thereof, available to a third party, you may only do so under the terms of this Agreement. You shall also provide a copy of this Agreement to such third party.
30
+
31
+ iii. If you submit for publication the results of research you perform on, using, or otherwise in connection with Research Materials, you must acknowledge the use of Research Materials in your publication.
32
+
33
+ iv. Your use of the Research Materials must comply with applicable laws and regulations (including Trade Control Laws) and adhere to the FAIR Acceptable Use Policy, which is hereby incorporated by reference into this Agreement.
34
+
35
+ 2. User Support. Your Noncommercial Research Use of the Research Materials is done at your own discretion; Meta does not process any information nor provide any service in relation to such use. Meta is under no obligation to provide any support services for the Research Materials. Any support provided is “as is”, “with all faults”, and without warranty of any kind.
36
+
37
+ 3. Disclaimer of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE RESEARCH MATERIALS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN “AS IS” BASIS, WITHOUT WARRANTIES OF ANY KIND, AND META DISCLAIMS ALL WARRANTIES OF ANY KIND, BOTH EXPRESS AND IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING THE RESEARCH MATERIALS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE RESEARCH MATERIALS AND ANY OUTPUT AND RESULTS.
38
+
39
+ 4. Limitation of Liability. IN NO EVENT WILL META OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY DIRECT OR INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF META OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
40
+
41
+ 5. Intellectual Property.
42
+
43
+ a. Subject to Meta’s ownership of Research Materials and derivatives made by or for Meta, with respect to any derivative works and modifications of the Research Materials that are made by you, as between you and Meta, you are and will be the owner of such derivative works and modifications.
44
+
45
+ b. If you institute litigation or other proceedings against Meta or any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Research Materials, outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by you, then any licenses granted to you under this Agreement shall terminate as of the date such litigation or claim is filed or instituted. You will indemnify and hold harmless Meta from and against any claim by any third party arising out of or related to your use or distribution of the Research Materials.
46
+
47
+ 6. Term and Termination. The term of this Agreement will commence upon your acceptance of this Agreement or access to the Research Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein. Meta may terminate this Agreement if you are in breach of any term or condition of this Agreement. Upon termination of this Agreement, you shall delete and cease use of the Research Materials. Sections 5, 6 and 9 shall survive the termination of this Agreement.
48
+
49
+ 7. Governing Law and Jurisdiction. This Agreement will be governed and construed under the laws of the State of California without regard to choice of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this Agreement. The courts of California shall have exclusive jurisdiction of any dispute arising out of this Agreement.
50
+
51
+ 8. Modifications and Amendments. Meta may modify this Agreement from time to time by posting a revised version at [https://github.com/facebookresearch/spiritlm/blob/main/LICENSE]; provided that they are similar in spirit to the current version of the Agreement, but may differ in detail to address new problems or concerns. All such changes will be effective immediately. Your continued use of the Research Materials after any modification to this Agreement constitutes your agreement to such modification. Except as provided in this Agreement, no modification or addition to any provision of this Agreement will be binding unless it is in writing and signed by an authorized representative of both you and Meta.
52
+
53
+
54
+ FAIR Acceptable Use Policy
55
+
56
+ The Fundamental AI Research (FAIR) team at Meta seeks to further understanding of new and existing research domains with the mission of advancing the state-of-the-art in artificial intelligence through open research for the benefit of all.
57
+
58
+ As part of this mission, Meta makes certain research materials available for noncommercial research use. Meta is committed to promoting the safe and responsible use of such research materials.
59
+
60
+
61
+ Prohibited Uses
62
+
63
+ You agree you will not use, or allow others to use, Research Materials to:
64
+
65
+ 1. Violate the law or others’ rights, including to:
66
+ a. Engage in, promote, generate, contribute to, encourage, plan, incite, or further illegal or unlawful activity or content, such as:
67
+ i. Violence or terrorism
68
+ ii. Exploitation or harm to children, including the solicitation, creation, acquisition, or dissemination of child exploitative content or failure to report Child Sexual Abuse Material
69
+ iii. Human trafficking, exploitation, and sexual violence
70
+ iv. The illegal distribution of information or materials to minors, including obscene materials, or failure to employ legally required age-gating in connection with such information or materials.
71
+ v. Sexual solicitation
72
+ vi. Any other criminal activity
73
+
74
+ b. Engage in, promote, incite, or facilitate the harassment, abuse, threatening, or bullying of individuals or groups of individuals
75
+
76
+ c. Engage in, promote, incite, or facilitate discrimination or other unlawful or harmful conduct in the provision of employment, employment benefits, credit, housing, other economic benefits, or other essential goods and services
77
+
78
+ d. Engage in the unauthorized or unlicensed practice of any profession including, but not limited to, financial, legal, medical/health, or related professional practices
79
+
80
+ e. Collect, process, disclose, generate, or infer health, demographic, or other sensitive personal or private information about individuals without rights and consents required by applicable laws
81
+
82
+ f. Engage in or facilitate any action or generate any content that infringes, misappropriates, or otherwise violates any third-party rights, including the outputs or results of any technology using FAIR research materials
83
+
84
+ g. Create, generate, or facilitate the creation of malicious code, malware, computer viruses or do anything else that could disable, overburden, interfere with or impair the proper working, integrity, operation or appearance of a website or computer system
85
+
86
+ 2. Engage in, promote, incite, facilitate, or assist in the planning or development of activities that present a risk of death or bodily harm to individuals, including use of research artifacts related to the following:
87
+
88
+ a. Military, warfare, nuclear industries or applications, espionage, use for materials or activities that are subject to the International Traffic Arms Regulations (ITAR) maintained by the United States Department of State
89
+
90
+ b. Guns and illegal weapons (including weapon development)
91
+
92
+ c. Illegal drugs and regulated/controlled substances
93
+
94
+ d. Operation of critical infrastructure, transportation technologies, or heavy machinery
95
+
96
+ e. Self-harm or harm to others, including suicide, cutting, and eating disorders
97
+
98
+ f. Any content intended to incite or promote violence, abuse, or any infliction of bodily harm to an individual
99
+
100
+ 3. Intentionally deceive or mislead others, including use of FAIR Research Materials related to the following:
101
+
102
+ a. Generating, promoting, or furthering fraud or the creation or promotion of disinformation
103
+
104
+ b. Generating, promoting, or furthering defamatory content, including the creation of defamatory statements, images, or other content
105
+
106
+ c. Generating, promoting, or further distributing spam
107
+
108
+ d. Impersonating another individual without consent, authorization, or legal right
109
+
110
+ e. Representing that outputs of FAIR research materials or outputs from technology using FAIR research materials are human-generated
111
+
112
+ f. Generating or facilitating false online engagement, including fake reviews and other means of fake online engagement
113
+
114
+ 4. Fail to appropriately disclose to end users any known dangers of your Research Materials.
115
+
116
+ Please report any violation of this Policy or other problems that could lead to a violation of this Policy by submitting a report here [https://docs.google.com/forms/d/e/1FAIpQLSeb11cryAopJ7LNrC4nxEUXrHY26hfkXQMf_uH-oFgA3WlYZQ/viewform].
MODEL_CARD.md ADDED
@@ -0,0 +1,81 @@
1
+ # Meta Spirit LM Model Card
2
+
3
+ ## Model Details
4
+
5
+ *Note: Use of this model is governed by the FAIR Noncommercial Research License.*
6
+
7
+ Spirit LM is a multimodal language model that freely mixes text and speech. The model can be prompted with either text or speech and is capable of generating outputs in either modality, while preserving the expressivity of the input prompt. The model is also able to learn new tasks across modalities such as automatic speech recognition, text-to-speech, and speech classification in a few-shot manner.
8
+
9
+ ## Model Developers
10
+ Meta
11
+
12
+ ## Variations
13
+ Spirit LM comes in two versions: Spirit LM Base that uses speech phonetic tokens and Spirit LM Expressive that models expressivity using pitch and style tokens in addition to the phonetic tokens.
14
+
15
+ ## Input
16
+ The models take as input text, speech, or a mixed sequence of the two.
17
+
18
+ ## Output
19
+ The models generate text, speech, or a mixed sequence of the two.
20
+
21
+ ## Model Architecture
22
+ ### Speech Tokenizer
23
+ Spirit LM uses 3 types of speech tokenizers: Phonetic Tokenizer (HuBERT), Pitch Tokenizer (VQ-VAE) and Style Tokenizer (Speechprop or Wav2vec2). We use Hifi-GAN to convert the speech tokens back to audio.
24
+
25
+ It is worth noting that in the associated paper, for Spirit LM Expressive, we used Speechprop to extract style tokens, while we use a Wav2vec2 model to extract style tokens in this release.
26
+
27
+ | | Model | Parameters | Input | Output |
28
+ |------------------------|--------------------------|------------|---------------------|--------------------|
29
+ | Phonetic Tokenizer | HuBERT+LinearQuantizer | 96M | Waveform | Phonetic Tokens |
30
+ | Pitch Tokenizer | VQ-VAE | 0.2M | Extracted F0 | Pitch Tokens |
31
+ | Style Tokenizer | Wav2vec2+LinearProjection| 95M | Waveform | Style Tokens |
32
+ | Base Speech Decoder | Hifi-GAN | 14M | Phonetic Tokens | Waveform |
33
+ | Expressive Speech Decoder | Hifi-GAN | 15M | Phonetic, Pitch, Style Tokens | Waveform |
34
+
35
+ ### Language Model
36
+ Spirit LM is initialized from the Llama-2 7B model.
37
+
38
+ | | Architecture | Parameters | Input/Output Tokens | Vocab Size |
39
+ |----------------------|----------------|------------|----------------------------------------------------------|------------|
40
+ | Spirit LM Base | Llama-2 7B | 7B | Text Tokens, Phonetic Tokens | 32512 |
41
+ | Spirit LM Expressive | Llama-2 7B | 7B | Text Tokens, Phonetic Tokens, Pitch Tokens, Style Tokens | 32768 |
42
+
43
+ ### Release Date
44
+ The models were trained between October and December 2023. The research paper was released on February 8th 2024. We released the model on October 18th 2024.
45
+
46
+ ### Status
47
+ This is a static model trained on an offline dataset.
48
+
49
+ ### License
50
+ We release the model under the FAIR Noncommercial Research License found in the [LICENSE](LICENSE) file in the root directory of this repo.
51
+
52
+ ### Research Paper
53
+ More information can be found in the paper ["SpiRit-LM: Interleaved Spoken and Written Language Model"](https://arxiv.org/pdf/2402.05755.pdf).
54
+
55
+ ## Hardware and Software
56
+ ### Training Factors
57
+ We used custom training libraries. The training of the released models has been performed on Meta’s Research Clusters.
58
+
59
+ The training of each model (Spirit LM Base and Spirit LM Expressive) takes 21K GPU hours of computation on hardware of type A100-80GB (TDP of 350-400W), not including the training of Llama-2.
60
+
61
+ ## Training Data
62
+ We trained the models on a combination of text-only datasets, speech-only datasets and aligned speech-text datasets. All the speech datasets are publicly available. Here are the statistics of the datasets we used:
63
+
64
+ | | Hours | Speech Tokens | Text Tokens |
65
+ |--------------|-------|---------------|-------------|
66
+ | Speech-only | 458K | 28.2B | - |
67
+ | Speech+Text | 111K | 7.0B | 1.4B |
68
+ | Text-only | - | - | 307B |
69
+
70
+ ## Evaluation Results
71
+ See evaluations for our models and detailed ablations in Section 4 and 5, and safety evaluations in Section 6 of the [research paper](https://arxiv.org/pdf/2402.05755.pdf).
72
+
73
+ ## Intended Use
74
+ ### Intended Use Cases
75
+ Spirit LM is intended for noncommercial research use in English.
76
+
77
+ ### Out-of-Scope Uses
78
+ Use in any manner that violates applicable laws or regulations (including trade compliance laws). Use in languages other than English. Use in any other way that is prohibited by the FAIR Noncommercial Research License and Acceptable Use Policy.
79
+
80
+ ## Ethical Considerations and Limitations
81
+ This model is built on Llama 2, which carries risks with use. Testing conducted to date has been in English and has not covered, nor could it cover, all scenarios. For these reasons, as with all LLMs, Llama 2’s potential outputs cannot be predicted in advance, and the model may in some instances produce inaccurate, biased, or otherwise objectionable responses to user prompts. The model’s speech capabilities are designed to analyze speaker-agnostic qualities of any input speech and to output speech in one of four pre-set voices. The model is meant for noncommercial research purposes only and should not be deployed in any consumer-facing applications.
README.md ADDED
@@ -0,0 +1,55 @@
1
+ # Meta Spirit LM: Interleaved Spoken and Written Language Model
2
+
3
+ This repository contains the model weights, inference code and evaluation scripts for the Spirit LM [paper](https://arxiv.org/pdf/2402.05755.pdf). You can find more generation samples on our [demo page](https://speechbot.github.io/spiritlm/).
4
+
5
+ ## Spirit LM Model Overview
6
+ <img src="assets/spiritlm_overview.png">
7
+
8
+ ## Installation Setup
9
+ ### Conda
10
+ ```
11
+ conda env create -f env.yml
12
+ pip install -e '.[eval]'
13
+
14
+ ```
15
+ ### Pip
16
+ ```
17
+ pip install -e '.[eval]'
18
+ ```
19
+
20
+ ### Dev
21
+ (Optional: only needed if you want to run the tests.)
22
+ ```
23
+ pip install -e '.[dev]'
24
+ ```
25
+
26
+ ## Checkpoints Setup
27
+ See [checkpoints/README.md](checkpoints/README.md)
28
+
29
+ ## Quick Start
30
+ ### Speech Tokenization
31
+ See [spiritlm/speech_tokenizer/README.md](spiritlm/speech_tokenizer/README.md)
32
+ ### Spirit LM Generation
33
+ See [spiritlm/model/README.md](spiritlm/model/README.md)
34
+ ### Speech-Text Sentiment Preservation benchmark (STSP)
35
+ See [spiritlm/eval/README.md](spiritlm/eval/README.md)
36
+
37
+ ## Model Card
38
+ More details of the model can be found in [MODEL_CARD.md](MODEL_CARD.md).
39
+
40
+ ## License
41
+ The present code is provided under the **FAIR Noncommercial Research License** found in [LICENSE](LICENSE).
42
+
43
+ ## Citation
44
+ ```
45
+ @misc{nguyen2024spiritlminterleavedspokenwritten,
46
+ title={SpiRit-LM: Interleaved Spoken and Written Language Model},
47
+ author={Tu Anh Nguyen and Benjamin Muller and Bokai Yu and Marta R. Costa-jussa and Maha Elbayad and Sravya Popuri and Paul-Ambroise Duquenne and Robin Algayres and Ruslan Mavlyutov and Itai Gat and Gabriel Synnaeve and Juan Pino and Benoit Sagot and Emmanuel Dupoux},
48
+ year={2024},
49
+ eprint={2402.05755},
50
+ archivePrefix={arXiv},
51
+ primaryClass={cs.CL},
52
+ url={https://arxiv.org/abs/2402.05755},
53
+ }
54
+ ```
55
+
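A minimal, non-distributed generation sketch is shown below for readers browsing this diff. It reuses the classes imported in examples/distributed_inference_recipe/run_dist.py later in this commit (`Spiritlm`, `GenerationInput`, `ContentType`, `OutputModality`); treat it as an illustration under those assumptions rather than the documented API in spiritlm/model/README.md.
```
# Illustrative sketch only; the supported usage is documented in spiritlm/model/README.md.
# Class and argument names mirror examples/distributed_inference_recipe/run_dist.py.
import torchaudio
from transformers import GenerationConfig
from spiritlm.model.spiritlm_model import (
    ContentType,
    GenerationInput,
    OutputModality,
    Spiritlm,
)

# 16 kHz example waveform shipped with the repo.
wav = torchaudio.load("examples/audio/7143-88743-0029.flac")[0].squeeze()

# Requires checkpoints/spiritlm_model/spirit-lm-expressive-7b (see checkpoints/README.md).
spirit_lm = Spiritlm("spirit-lm-expressive-7b")

outs = spirit_lm.generate(
    output_modality=OutputModality.ARBITRARY,  # model may continue in text and/or speech
    interleaved_inputs=[GenerationInput(content=wav, content_type=ContentType.SPEECH)],
    generation_config=GenerationConfig(
        temperature=0.9, top_p=0.95, max_new_tokens=200, do_sample=True
    ),
)
print(outs)
```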
assets/spiritlm_overview.png ADDED
checkpoints/README.md ADDED
@@ -0,0 +1,52 @@
1
+ # Spirit LM Checkpoints
2
+
3
+ ## Download Checkpoints
4
+ To access and download Spirit LM Checkpoints, please request the model artifacts in this link:
5
+
6
+ [https://ai.meta.com/resources/models-and-libraries/spirit-lm-downloads/](https://ai.meta.com/resources/models-and-libraries/spirit-lm-downloads/)
7
+
8
+ Upon approval, you will then receive an email with download links to each model artifact.
9
+
10
+ Please note that Spirit LM is made available under the **FAIR Noncommercial Research License**
11
+ found in the [LICENSE](../LICENSE) file in the root directory of this source tree, together with the accompanying Acceptable Use Policy.
12
+
13
+ ## Structure
14
+ The checkpoints directory should look like this:
15
+ ```
16
+ checkpoints/
17
+ ├── README.md
18
+ ├── speech_tokenizer
19
+ │   ├── hifigan_spiritlm_base
20
+ │   │   ├── config.json
21
+ │   │   ├── generator.pt
22
+ │   │   ├── speakers.txt
23
+ │   │   └── styles.txt
24
+ │   ├── hifigan_spiritlm_expressive_w2v2
25
+ │   │   ├── config.json
26
+ │   │   ├── generator.pt
27
+ │   │   └── speakers.txt
28
+ │   ├── hubert_25hz
29
+ │   │   ├── L11_quantizer_500.pt
30
+ │   │   └── mhubert_base_25hz.pt
31
+ │   ├── style_encoder_w2v2
32
+ │   │   ├── config.json
33
+ │   │   └── pytorch_model.bin
34
+ │   └── vqvae_f0_quantizer
35
+ │   ├── config.yaml
36
+ │   └── model.pt
37
+ └── spiritlm_model
38
+ ├── spirit-lm-base-7b
39
+ │   ├── config.json
40
+ │   ├── generation_config.json
41
+ │   ├── pytorch_model.bin
42
+ │   ├── special_tokens_map.json
43
+ │   ├── tokenizer_config.json
44
+ │   └── tokenizer.model
45
+ └── spirit-lm-expressive-7b
46
+ ├── config.json
47
+ ├── generation_config.json
48
+ ├── pytorch_model.bin
49
+ ├── special_tokens_map.json
50
+ ├── tokenizer_config.json
51
+ └── tokenizer.model
52
+ ```
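As a lightweight companion to the tree above (the repository also ships spiritlm/eval/stsp/sanity_check_download.py for the evaluation data), a quick check that the required checkpoint files are in place could look like the sketch below; the path list is copied from the tree and everything else is an assumption.
```
from pathlib import Path

# Key weight files from the directory tree above (configs omitted for brevity).
EXPECTED = [
    "speech_tokenizer/hifigan_spiritlm_base/generator.pt",
    "speech_tokenizer/hifigan_spiritlm_expressive_w2v2/generator.pt",
    "speech_tokenizer/hubert_25hz/mhubert_base_25hz.pt",
    "speech_tokenizer/hubert_25hz/L11_quantizer_500.pt",
    "speech_tokenizer/style_encoder_w2v2/pytorch_model.bin",
    "speech_tokenizer/vqvae_f0_quantizer/model.pt",
    "spiritlm_model/spirit-lm-base-7b/pytorch_model.bin",
    "spiritlm_model/spirit-lm-expressive-7b/pytorch_model.bin",
]


def check(root: str = "checkpoints") -> None:
    # Report any expected checkpoint file that is missing under `root`.
    missing = [p for p in EXPECTED if not (Path(root) / p).exists()]
    if missing:
        print("Missing checkpoint files:")
        for p in missing:
            print(f"  - {p}")
    else:
        print("All expected checkpoint files found.")


if __name__ == "__main__":
    check()
```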
checkpoints/speech_tokenizer/hifigan_spiritlm_base/config.json ADDED
@@ -0,0 +1,55 @@
1
+ {
2
+
3
+ "resblock": "1",
4
+ "num_gpus": 0,
5
+ "batch_size": 16,
6
+ "learning_rate": 0.0002,
7
+ "adam_b1": 0.8,
8
+ "adam_b2": 0.99,
9
+ "lr_decay": 0.999,
10
+ "seed": 1234,
11
+
12
+ "upsample_rates": [5,4,4,4,2],
13
+ "upsample_kernel_sizes": [11,8,8,8,4],
14
+ "upsample_initial_channel": 512,
15
+ "resblock_kernel_sizes": [3,7,11],
16
+ "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
17
+ "num_embeddings": 501,
18
+ "embedding_dim": 128,
19
+ "model_in_dim": 384,
20
+
21
+ "segment_size": 8960,
22
+ "code_hop_size": 640,
23
+ "f0": false,
24
+ "num_mels": 80,
25
+ "num_freq": 1025,
26
+ "n_fft": 1024,
27
+ "hop_size": 256,
28
+ "win_size": 1024,
29
+
30
+ "multispkr": "from_input_file",
31
+ "num_speakers": 4,
32
+ "multistyle": "from_input_file",
33
+ "num_styles": 34,
34
+
35
+ "dur_prediction_weight": 1.0,
36
+ "dur_predictor_params": {
37
+ "encoder_embed_dim": 128,
38
+ "var_pred_hidden_dim": 128,
39
+ "var_pred_kernel_size": 3,
40
+ "var_pred_dropout": 0.5
41
+ },
42
+
43
+ "sampling_rate": 16000,
44
+
45
+ "fmin": 0,
46
+ "fmax": 8000,
47
+ "fmax_for_loss": null,
48
+
49
+ "num_workers": 4,
50
+
51
+ "dist_config": {
52
+ "dist_backend": "nccl",
53
+ "dist_url": "env://"
54
+ }
55
+ }
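A note on the numbers above: at sampling_rate 16000, code_hop_size 640 means one phonetic unit every 640 / 16000 = 0.04 s, i.e. the 25 Hz rate of the hubert_25hz tokenizer, and num_embeddings 501 lines up with the 500-cluster L11_quantizer_500.pt (presumably plus one reserved index). num_speakers 4 and num_styles 34 match the speakers.txt and styles.txt files that follow.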
checkpoints/speech_tokenizer/hifigan_spiritlm_base/generator.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d66c49067aeff93b14b038f143c2dc9ed981671956512d8e702897416f13c459
3
+ size 57512631
checkpoints/speech_tokenizer/hifigan_spiritlm_base/speakers.txt ADDED
@@ -0,0 +1,4 @@
1
+ ex04
2
+ ex02
3
+ ex03
4
+ ex01
checkpoints/speech_tokenizer/hifigan_spiritlm_base/styles.txt ADDED
@@ -0,0 +1,34 @@
1
+ read-default
2
+ read-enunciated
3
+ read-confused
4
+ read-laughing
5
+ read-whisper
6
+ read-sad
7
+ read-happy
8
+ conv-projected
9
+ conv-default
10
+ conv-sympathetic
11
+ conv-fast
12
+ conv-disgusted
13
+ conv-laughing
14
+ conv-calm
15
+ conv-sarcastic
16
+ conv-whisper
17
+ conv-angry
18
+ conv-sad
19
+ conv-happy
20
+ conv-enunciated
21
+ conv-awe
22
+ read-singing
23
+ conv-confused
24
+ conv-fearful
25
+ conv-narration
26
+ conv-sleepy
27
+ conv-child
28
+ conv-animal
29
+ conv-childdir
30
+ conv-animaldir
31
+ conv-bored
32
+ conv-desire
33
+ conv-nonverbal
34
+ read-narration
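(The 34 entries above match num_styles: 34 in the hifigan_spiritlm_base config, and the four speaker IDs in speakers.txt match num_speakers: 4.)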
checkpoints/speech_tokenizer/hifigan_spiritlm_expressive_w2v2/config.json ADDED
@@ -0,0 +1,60 @@
1
+ {
2
+
3
+ "resblock": "1",
4
+ "num_gpus": 0,
5
+ "batch_size": 128,
6
+ "learning_rate": 0.0002,
7
+ "adam_b1": 0.8,
8
+ "adam_b2": 0.99,
9
+ "lr_decay": 0.999,
10
+ "seed": 1234,
11
+
12
+ "upsample_rates": [5,4,4,4,2],
13
+ "upsample_kernel_sizes": [11,8,8,8,4],
14
+ "upsample_initial_channel": 512,
15
+ "resblock_kernel_sizes": [3,7,11],
16
+ "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
17
+
18
+ "multispkr": "from_input_file",
19
+ "multistyle": null,
20
+
21
+ "dur_prediction_weight": 1.0,
22
+ "dur_predictor_params": {
23
+ "encoder_embed_dim": 128,
24
+ "var_pred_hidden_dim": 128,
25
+ "var_pred_kernel_size": 3,
26
+ "var_pred_dropout": 0.5
27
+ },
28
+
29
+ "segment_size": 17920,
30
+ "code_hop_size": 640,
31
+ "f0_hop_size": 1280,
32
+ "style_hop_size": 16000,
33
+
34
+ "num_embeddings": 501,
35
+ "num_f0_tokens": 64,
36
+ "num_style_tokens": 100,
37
+ "num_speakers": 4,
38
+
39
+ "embedding_dim": 128,
40
+ "model_in_dim": 512,
41
+
42
+ "num_mels": 80,
43
+ "num_freq": 1025,
44
+ "n_fft": 1024,
45
+ "hop_size": 256,
46
+ "win_size": 1024,
47
+
48
+ "sampling_rate": 16000,
49
+
50
+ "fmin": 0,
51
+ "fmax": 8000,
52
+ "fmax_for_loss": null,
53
+
54
+ "num_workers": 4,
55
+
56
+ "dist_config": {
57
+ "dist_backend": "nccl",
58
+ "dist_url": "env://"
59
+ }
60
+ }
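With the same 16 kHz sampling rate, the hop sizes above imply one phonetic token every 640 / 16000 = 40 ms (25 Hz), one pitch token every 1280 / 16000 = 80 ms (12.5 Hz, drawn from num_f0_tokens = 64), and one style token every style_hop_size = 16000 samples, i.e. one per second (drawn from num_style_tokens = 100).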
checkpoints/speech_tokenizer/hifigan_spiritlm_expressive_w2v2/generator.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9de97cbc336e6113c27f17560988a22e76df25af6a8e944ad01676aabec31326
3
+ size 59414584
checkpoints/speech_tokenizer/hifigan_spiritlm_expressive_w2v2/speakers.txt ADDED
@@ -0,0 +1,4 @@
1
+ ex04
2
+ ex02
3
+ ex03
4
+ ex01
checkpoints/speech_tokenizer/hubert_25hz/L11_quantizer_500.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:06b408c0a487f0218e8aba52ca30d9de54e0c36af8bebfd151265647d221080b
3
+ size 5222060
checkpoints/speech_tokenizer/hubert_25hz/mhubert_base_25hz.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1421f060cf92b9d2ea72dcecd7f30ce549e1112e24b62b43f2ec3026301051bb
3
+ size 383333938
checkpoints/speech_tokenizer/style_encoder_w2v2/config.json ADDED
@@ -0,0 +1,321 @@
1
+ {
2
+ "_name_or_path": "facebook/wav2vec2-base",
3
+ "activation_dropout": 0.0,
4
+ "adapter_kernel_size": 3,
5
+ "adapter_stride": 2,
6
+ "add_adapter": false,
7
+ "apply_spec_augment": true,
8
+ "architectures": [
9
+ "Wav2Vec2ForPooledSequenceClassification"
10
+ ],
11
+ "attention_dropout": 0.1,
12
+ "bos_token_id": 1,
13
+ "classifier_proj_size": 256,
14
+ "codevector_dim": 256,
15
+ "contrastive_logits_temperature": 0.1,
16
+ "conv_bias": false,
17
+ "conv_dim": [
18
+ 512,
19
+ 512,
20
+ 512,
21
+ 512,
22
+ 512,
23
+ 512,
24
+ 512
25
+ ],
26
+ "conv_kernel": [
27
+ 10,
28
+ 3,
29
+ 3,
30
+ 3,
31
+ 3,
32
+ 2,
33
+ 2
34
+ ],
35
+ "conv_stride": [
36
+ 5,
37
+ 2,
38
+ 2,
39
+ 2,
40
+ 2,
41
+ 2,
42
+ 2
43
+ ],
44
+ "ctc_loss_reduction": "sum",
45
+ "ctc_zero_infinity": false,
46
+ "diversity_loss_weight": 0.1,
47
+ "do_stable_layer_norm": false,
48
+ "eos_token_id": 2,
49
+ "feat_extract_activation": "gelu",
50
+ "feat_extract_norm": "group",
51
+ "feat_proj_dropout": 0.1,
52
+ "feat_quantizer_dropout": 0.0,
53
+ "final_dropout": 0.0,
54
+ "freeze_feat_extract_train": true,
55
+ "hidden_act": "gelu",
56
+ "hidden_dropout": 0.1,
57
+ "hidden_size": 768,
58
+ "id2label": {
59
+ "0": "0",
60
+ "1": "1",
61
+ "10": "10",
62
+ "11": "11",
63
+ "12": "12",
64
+ "13": "13",
65
+ "14": "14",
66
+ "15": "15",
67
+ "16": "16",
68
+ "17": "17",
69
+ "18": "18",
70
+ "19": "19",
71
+ "2": "2",
72
+ "20": "20",
73
+ "21": "21",
74
+ "22": "22",
75
+ "23": "23",
76
+ "24": "24",
77
+ "25": "25",
78
+ "26": "26",
79
+ "27": "27",
80
+ "28": "28",
81
+ "29": "29",
82
+ "3": "3",
83
+ "30": "30",
84
+ "31": "31",
85
+ "32": "32",
86
+ "33": "33",
87
+ "34": "34",
88
+ "35": "35",
89
+ "36": "36",
90
+ "37": "37",
91
+ "38": "38",
92
+ "39": "39",
93
+ "4": "4",
94
+ "40": "40",
95
+ "41": "41",
96
+ "42": "42",
97
+ "43": "43",
98
+ "44": "44",
99
+ "45": "45",
100
+ "46": "46",
101
+ "47": "47",
102
+ "48": "48",
103
+ "49": "49",
104
+ "5": "5",
105
+ "50": "50",
106
+ "51": "51",
107
+ "52": "52",
108
+ "53": "53",
109
+ "54": "54",
110
+ "55": "55",
111
+ "56": "56",
112
+ "57": "57",
113
+ "58": "58",
114
+ "59": "59",
115
+ "6": "6",
116
+ "60": "60",
117
+ "61": "61",
118
+ "62": "62",
119
+ "63": "63",
120
+ "64": "64",
121
+ "65": "65",
122
+ "66": "66",
123
+ "67": "67",
124
+ "68": "68",
125
+ "69": "69",
126
+ "7": "7",
127
+ "70": "70",
128
+ "71": "71",
129
+ "72": "72",
130
+ "73": "73",
131
+ "74": "74",
132
+ "75": "75",
133
+ "76": "76",
134
+ "77": "77",
135
+ "78": "78",
136
+ "79": "79",
137
+ "8": "8",
138
+ "80": "80",
139
+ "81": "81",
140
+ "82": "82",
141
+ "83": "83",
142
+ "84": "84",
143
+ "85": "85",
144
+ "86": "86",
145
+ "87": "87",
146
+ "88": "88",
147
+ "89": "89",
148
+ "9": "9",
149
+ "90": "90",
150
+ "91": "91",
151
+ "92": "92",
152
+ "93": "93",
153
+ "94": "94",
154
+ "95": "95",
155
+ "96": "96",
156
+ "97": "97",
157
+ "98": "98",
158
+ "99": "99"
159
+ },
160
+ "initializer_range": 0.02,
161
+ "intermediate_size": 3072,
162
+ "label2id": {
163
+ "0": "0",
164
+ "1": "1",
165
+ "10": "10",
166
+ "11": "11",
167
+ "12": "12",
168
+ "13": "13",
169
+ "14": "14",
170
+ "15": "15",
171
+ "16": "16",
172
+ "17": "17",
173
+ "18": "18",
174
+ "19": "19",
175
+ "2": "2",
176
+ "20": "20",
177
+ "21": "21",
178
+ "22": "22",
179
+ "23": "23",
180
+ "24": "24",
181
+ "25": "25",
182
+ "26": "26",
183
+ "27": "27",
184
+ "28": "28",
185
+ "29": "29",
186
+ "3": "3",
187
+ "30": "30",
188
+ "31": "31",
189
+ "32": "32",
190
+ "33": "33",
191
+ "34": "34",
192
+ "35": "35",
193
+ "36": "36",
194
+ "37": "37",
195
+ "38": "38",
196
+ "39": "39",
197
+ "4": "4",
198
+ "40": "40",
199
+ "41": "41",
200
+ "42": "42",
201
+ "43": "43",
202
+ "44": "44",
203
+ "45": "45",
204
+ "46": "46",
205
+ "47": "47",
206
+ "48": "48",
207
+ "49": "49",
208
+ "5": "5",
209
+ "50": "50",
210
+ "51": "51",
211
+ "52": "52",
212
+ "53": "53",
213
+ "54": "54",
214
+ "55": "55",
215
+ "56": "56",
216
+ "57": "57",
217
+ "58": "58",
218
+ "59": "59",
219
+ "6": "6",
220
+ "60": "60",
221
+ "61": "61",
222
+ "62": "62",
223
+ "63": "63",
224
+ "64": "64",
225
+ "65": "65",
226
+ "66": "66",
227
+ "67": "67",
228
+ "68": "68",
229
+ "69": "69",
230
+ "7": "7",
231
+ "70": "70",
232
+ "71": "71",
233
+ "72": "72",
234
+ "73": "73",
235
+ "74": "74",
236
+ "75": "75",
237
+ "76": "76",
238
+ "77": "77",
239
+ "78": "78",
240
+ "79": "79",
241
+ "8": "8",
242
+ "80": "80",
243
+ "81": "81",
244
+ "82": "82",
245
+ "83": "83",
246
+ "84": "84",
247
+ "85": "85",
248
+ "86": "86",
249
+ "87": "87",
250
+ "88": "88",
251
+ "89": "89",
252
+ "9": "9",
253
+ "90": "90",
254
+ "91": "91",
255
+ "92": "92",
256
+ "93": "93",
257
+ "94": "94",
258
+ "95": "95",
259
+ "96": "96",
260
+ "97": "97",
261
+ "98": "98",
262
+ "99": "99"
263
+ },
264
+ "layer_norm_eps": 1e-05,
265
+ "layerdrop": 0.0,
266
+ "mask_channel_length": 10,
267
+ "mask_channel_min_space": 1,
268
+ "mask_channel_other": 0.0,
269
+ "mask_channel_prob": 0.0,
270
+ "mask_channel_selection": "static",
271
+ "mask_feature_length": 10,
272
+ "mask_feature_min_masks": 0,
273
+ "mask_feature_prob": 0.0,
274
+ "mask_time_length": 10,
275
+ "mask_time_min_masks": 2,
276
+ "mask_time_min_space": 1,
277
+ "mask_time_other": 0.0,
278
+ "mask_time_prob": 0.05,
279
+ "mask_time_selection": "static",
280
+ "model_type": "wav2vec2",
281
+ "no_mask_channel_overlap": false,
282
+ "no_mask_time_overlap": false,
283
+ "num_adapter_layers": 3,
284
+ "num_attention_heads": 12,
285
+ "num_codevector_groups": 2,
286
+ "num_codevectors_per_group": 320,
287
+ "num_conv_pos_embedding_groups": 16,
288
+ "num_conv_pos_embeddings": 128,
289
+ "num_feat_extract_layers": 7,
290
+ "num_hidden_layers": 12,
291
+ "num_negatives": 100,
292
+ "output_hidden_size": 768,
293
+ "pad_token_id": 0,
294
+ "proj_codevector_dim": 256,
295
+ "tdnn_dilation": [
296
+ 1,
297
+ 2,
298
+ 3,
299
+ 1,
300
+ 1
301
+ ],
302
+ "tdnn_dim": [
303
+ 512,
304
+ 512,
305
+ 512,
306
+ 512,
307
+ 1500
308
+ ],
309
+ "tdnn_kernel": [
310
+ 5,
311
+ 3,
312
+ 3,
313
+ 1,
314
+ 1
315
+ ],
316
+ "torch_dtype": "float32",
317
+ "transformers_version": "4.25.1",
318
+ "use_weighted_layer_sum": false,
319
+ "vocab_size": 32,
320
+ "xvector_output_dim": 512
321
+ }
checkpoints/speech_tokenizer/style_encoder_w2v2/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:819b7a87bfcf4041252c3f1825915924e10d5a06a2a66da546beaca98c7ab8bc
3
+ size 378451177
checkpoints/speech_tokenizer/vqvae_f0_quantizer/config.yaml ADDED
@@ -0,0 +1,59 @@
1
+
2
+ seed: 1234
3
+
4
+ # Data
5
+ f0_path: ''
6
+ p_train: 0.95
7
+ min_frames: null
8
+ batch_size: 128
9
+ features: f0_interp,vuv
10
+ out_features: norm_f0_interp,vuv
11
+ segment_size: null
12
+ segment_multi: 16
13
+ num_workers: 4
14
+ vuv_scale: 2
15
+ speaker_stats: ''
16
+ recon_loss_fn: l1_loss
17
+
18
+
19
+ # Optimization
20
+ learning_rate: 0.0002
21
+ adam_b1: 0.8
22
+ adam_b2: 0.99
23
+ lr_decay: 0.999
24
+ lambda_commit: 0.02
25
+
26
+ # VQ params
27
+ vq_params:
28
+ l_bins: 64
29
+ emb_width: 128
30
+ mu: 0.99
31
+ levels: 1
32
+
33
+ # Encoder params
34
+ encoder_params:
35
+ input_emb_width: 2
36
+ output_emb_width: 128
37
+ levels: 1
38
+ downs_t:
39
+ - 4
40
+ strides_t:
41
+ - 2
42
+ width: 32
43
+ depth: 4
44
+ m_conv: 1.0
45
+ dilation_growth_rate: 3
46
+
47
+ # Decoder params
48
+ decoder_params:
49
+ input_emb_width: 2
50
+ output_emb_width: 128
51
+ levels: 1
52
+ downs_t:
53
+ - 4
54
+ strides_t:
55
+ - 2
56
+ width: 32
57
+ depth: 4
58
+ m_conv: 1.0
59
+ dilation_growth_rate: 3
checkpoints/speech_tokenizer/vqvae_f0_quantizer/model.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f4321b0c19b47279ab3d76c4b6e85bbc439156a4dd6478919856accaf4180382
3
+ size 2600601
data/examples/pred.jsonl ADDED
@@ -0,0 +1,5 @@
1
+ {"pred": "angry", "id": 4792320029370491913}
2
+ {"pred": "neutral", "id": -5682350483296949563}
3
+ {"pred": "amused", "id": -8754508989367964614}
4
+ {"pred": "angry", "id": -9018665079841831624}
5
+ {"pred": "neutral", "id": 1159246029716120600}
data/examples/ref.jsonl ADDED
@@ -0,0 +1,5 @@
1
+ {"emotion": "angry", "sentiment": "negative", "wav_path": "emov/sam/Angry/anger_281-308_0286.wav", "split": "test", "speaker": "sam", "id": 4792320029370491913}
2
+ {"emotion": "neutral", "sentiment": "neutral", "wav_path": "emov/sam/Neutral/neutral_281-308_0286.wav", "split": "test", "speaker": "sam", "id": -5682350483296949563}
3
+ {"emotion": "amused", "sentiment": "positive", "wav_path": "emov/sam/Amused/amused_281-308_0286.wav", "split": "test", "speaker": "sam", "id": -8754508989367964614}
4
+ {"emotion": "angry", "sentiment": "negative", "wav_path": "emov/jenie/Angry/anger_57-84_0084.wav", "split": "test", "speaker": "jenie", "id": -9018665079841831624}
5
+ {"emotion": "neutral", "sentiment": "neutral", "wav_path": "emov/jenie/Neutral/neutral_57-84_0084.wav", "split": "test", "speaker": "jenie", "id": 1159246029716120600}
env.yml ADDED
@@ -0,0 +1,19 @@
1
+ name: spiritlm_test
2
+ channels:
3
+ - pytorch
4
+ - nvidia
5
+ dependencies:
6
+ - python=3.9
7
+ - pip
8
+ - pytorch-cuda=11.8
9
+ - pytorch
10
+ - torchaudio
11
+ - pip:
12
+ - omegaconf==2.2.0
13
+ - librosa~=0.10
14
+ - local-attention~=1.9
15
+ - encodec~=0.1
16
+ - transformers
17
+ - fairscale~=0.4
18
+ - sentencepiece
19
+ - torchfcpe~=0.0.4
examples/audio/7143-88743-0029.flac ADDED
Binary file (147 kB).
 
examples/distributed_inference_recipe/multi_nodes.slurm ADDED
@@ -0,0 +1,24 @@
1
+ #!/bin/bash
2
+
3
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
4
+ # All rights reserved.
5
+ #
6
+ # This source code is licensed under the FAIR Noncommercial Research License
7
+ # found in the LICENSE file in the root directory of this source tree.
8
+
9
+ #SBATCH --job-name=spiritlm
10
+ #SBATCH --ntasks-per-node=1
11
+ #SBATCH --gpus-per-node=8
12
+ #SBATCH --nodes=2
13
+ #SBATCH --cpus-per-task=12
14
+ #SBATCH --output=./logs/%j.stdout
15
+ #SBATCH --error=./logs/%j.stderr
16
+ #SBATCH --time=01:00:00
17
+
18
+ set -e
19
+
20
+ srun bash -c 'torchrun --nnodes $SLURM_JOB_NUM_NODES --nproc-per-node $SLURM_GPUS_ON_NODE \
21
+ --node-rank $SLURM_PROCID \
22
+ --master-addr $(scontrol show hostnames $SLURM_NODELIST | head -n1) \
23
+ --master-port 12345 \
24
+ examples/distributed_inference_recipe/run_dist.py'
examples/distributed_inference_recipe/run_dist.py ADDED
@@ -0,0 +1,89 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the FAIR Noncommercial Research License
5
+ # found in the LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Usage example:
9
+
10
+ cd {SPIRITLM ROOT FOLDER}
11
+ export PYTHONPATH=.
12
+
13
+ Single node, multi-gpus:
14
+ (Assume that your machine has 8 GPUs)
15
+ torchrun --nnodes 1 --nproc-per-node 8 examples/distributed_inference_recipe/run_dist.py
16
+
17
+ Multi-nodes, multi-gpus:
18
+ (2 nodes, 8 GPUs per node, via sbatch)
19
+ mkdir -p logs
20
+ sbatch examples/distributed_inference_recipe/multi_nodes.slurm
21
+ """
22
+
23
+ import os
24
+
25
+ import torch
26
+ import torch.distributed as dist
27
+ import torchaudio
28
+ from spiritlm.model.spiritlm_model import (
29
+ ContentType,
30
+ GenerationInput,
31
+ OutputModality,
32
+ Spiritlm,
33
+ )
34
+ from torch.utils.data import TensorDataset
35
+ from torch.utils.data.distributed import DistributedSampler
36
+ from transformers import GenerationConfig, set_seed
37
+
38
+
39
+ def run(seed: int = 0):
40
+ world_size = int(os.environ["WORLD_SIZE"])
41
+ world_rank = int(os.environ["RANK"])
42
+ print(
43
+ f"Running distributed inference with world_size: {world_size}, world_rank: {world_rank}"
44
+ )
45
+ dist.init_process_group("nccl", rank=world_rank, world_size=world_size)
46
+
47
+ set_seed(seed)
48
+
49
+ wav = torchaudio.load("examples/audio/7143-88743-0029.flac")[0].squeeze()
50
+
51
+ # fake repeated dataset
52
+ dataset = TensorDataset(wav.repeat(32, 1))
53
+
54
+ sampler = DistributedSampler(dataset=dataset)
55
+ loader = torch.utils.data.DataLoader(
56
+ dataset=dataset,
57
+ batch_size=1, # don't change
58
+ sampler=sampler,
59
+ num_workers=4,
60
+ )
61
+
62
+ spirit_lm = Spiritlm("spirit-lm-expressive-7b")
63
+
64
+ for _, data in enumerate(loader):
65
+ outs = spirit_lm.generate(
66
+ output_modality=OutputModality.ARBITRARY,
67
+ interleaved_inputs=[
68
+ GenerationInput(
69
+ content=data[0], # 0 because of batch size 1
70
+ content_type=ContentType.SPEECH,
71
+ )
72
+ ],
73
+ generation_config=GenerationConfig(
74
+ temperature=0.9,
75
+ top_p=0.95,
76
+ max_new_tokens=200,
77
+ do_sample=True,
78
+ ),
79
+ )
80
+ print(f"outs: {outs}")
81
+
82
+
83
+ def setup_env():
84
+ os.environ["OMP_NUM_THREADS"] = "1"
85
+
86
+
87
+ if __name__ == "__main__":
88
+ setup_env()
89
+ run()
examples/speech_generation/spirit_model.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
examples/speech_tokenizer/spiritlm_speech_tokenizer.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
requirements.dev.txt ADDED
@@ -0,0 +1 @@
1
+ pytest
requirements.txt ADDED
@@ -0,0 +1,9 @@
1
+ omegaconf>=2.2.0
2
+ librosa>=0.10
3
+ local-attention>=1.9
4
+ encodec>=0.1
5
+ transformers
6
+ fairscale>=0.4
7
+ sentencepiece
8
+ pyarrow>=14.0
9
+ torchfcpe>=0.0.4
setup.py ADDED
@@ -0,0 +1,60 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the FAIR Noncommercial Research License
5
+ # found in the LICENSE file in the root directory of this source tree.
6
+
7
+ import os
8
+ from pathlib import Path
9
+
10
+ from setuptools import find_packages, setup
11
+
12
+ NAME = "spiritlm"
13
+ VERSION = "0.1.0"
14
+ DESCRIPTION = "Interleaved Spoken and Written Language Model"
15
+ URL = "https://github.com/facebookresearch/spiritlm"
16
+ KEYWORDS = [
17
+ "Language Model, Speech Language Model, Multimodal, Crossmodal, Expressivity Modeling"
18
+ ]
19
+ LICENSE = "FAIR Noncommercial Research License"
20
+
21
+
22
+ def _get_long_description():
23
+ with (Path(__file__).parent / "README.md").open(encoding="utf-8") as file:
24
+ long_description = file.read()
25
+ return long_description
26
+
27
+
28
+ def _read_reqs(relpath):
29
+ fullpath = os.path.join(os.path.dirname(__file__), relpath)
30
+ with open(fullpath) as f:
31
+ return [
32
+ s.strip() for s in f.readlines() if (s.strip() and not s.startswith("#"))
33
+ ]
34
+
35
+
36
+ setup(
37
+ name=NAME,
38
+ version=VERSION,
39
+ description=DESCRIPTION,
40
+ long_description=_get_long_description(),
41
+ long_description_content_type="text/plain",
42
+ url=URL,
43
+ license=LICENSE,
44
+ author="Meta",
45
+ keywords=KEYWORDS,
46
+ classifiers=[
47
+ "Intended Audience :: Science/Research",
48
+ "License :: FAIR Noncommercial Research License",
49
+ "Topic :: Multimedia :: Sound/Audio",
50
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
51
+ ],
52
+ packages=find_packages(),
53
+ zip_safe=False,
54
+ python_requires=">=3.9",
55
+ install_requires=_read_reqs("requirements.txt"),
56
+ extras_require={
57
+ "dev": ["pytest"],
58
+ "eval": ["pandas"],
59
+ },
60
+ )
spiritlm.egg-info/PKG-INFO ADDED
@@ -0,0 +1,84 @@
1
+ Metadata-Version: 2.1
2
+ Name: spiritlm
3
+ Version: 0.1.0
4
+ Summary: Interleaved Spoken and Written Language Model
5
+ Home-page: https://github.com/facebookresearch/spiritlm
6
+ Author: Meta
7
+ License: FAIR Noncommercial Research License
8
+ Keywords: Language Model, Speech Language Model, Multimodal, Crossmodal, Expressivity Modeling
9
+ Classifier: Intended Audience :: Science/Research
10
+ Classifier: License :: FAIR Noncommercial Research License
11
+ Classifier: Topic :: Multimedia :: Sound/Audio
12
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
13
+ Requires-Python: >=3.9
14
+ Description-Content-Type: text/plain
15
+ License-File: LICENSE
16
+ Requires-Dist: omegaconf>=2.2.0
17
+ Requires-Dist: librosa>=0.10
18
+ Requires-Dist: local-attention>=1.9
19
+ Requires-Dist: encodec>=0.1
20
+ Requires-Dist: transformers
21
+ Requires-Dist: fairscale>=0.4
22
+ Requires-Dist: sentencepiece
23
+ Requires-Dist: pyarrow>=14.0
24
+ Requires-Dist: torchfcpe>=0.0.4
25
+ Provides-Extra: dev
26
+ Requires-Dist: pytest; extra == "dev"
27
+ Provides-Extra: eval
28
+ Requires-Dist: pandas; extra == "eval"
29
+
30
+ # Meta Spirit LM: Interleaved Spoken and Written Language Model
31
+
32
+ This repository contains the model weights, inference code and evaluation scripts for the Spirit LM [paper](https://arxiv.org/pdf/2402.05755.pdf). You can find more generation samples on our [demo page](https://speechbot.github.io/spiritlm/).
33
+
34
+ ## Spirit LM Model Overview
35
+ <img src="assets/spiritlm_overview.png">
36
+
37
+ ## Installation Setup
38
+ ### Conda
39
+ ```
40
+ conda env create -f env.yml
41
+ pip install -e '.[eval]'
42
+
43
+ ```
44
+ ### Pip
45
+ ```
46
+ pip install -e '.[eval]'
47
+ ```
48
+
49
+ ### Dev
50
+ (Optional: only needed if you want to run the tests.)
51
+ ```
52
+ pip install -e '.[dev]'
53
+ ```
54
+
55
+ ## Checkpoints Setup
56
+ See [checkpoints/README.md](checkpoints/README.md)
57
+
58
+ ## Quick Start
59
+ ### Speech Tokenization
60
+ See [spiritlm/speech_tokenizer/README.md](spiritlm/speech_tokenizer/README.md)
61
+ ### Spirit LM Generation
62
+ See [spiritlm/model/README.md](spiritlm/model/README.md)
63
+ ### Speech-Text Sentiment Preservation benchmark (STSP)
64
+ See [spiritlm/eval/README.md](spiritlm/eval/README.md)
65
+
66
+ ## Model Card
67
+ More details of the model can be found in [MODEL_CARD.md](MODEL_CARD.md).
68
+
69
+ ## License
70
+ The present code is provided under the **FAIR Noncommercial Research License** found in [LICENSE](LICENSE).
71
+
72
+ ## Citation
73
+ ```
74
+ @misc{nguyen2024spiritlminterleavedspokenwritten,
75
+ title={SpiRit-LM: Interleaved Spoken and Written Language Model},
76
+ author={Tu Anh Nguyen and Benjamin Muller and Bokai Yu and Marta R. Costa-jussa and Maha Elbayad and Sravya Popuri and Paul-Ambroise Duquenne and Robin Algayres and Ruslan Mavlyutov and Itai Gat and Gabriel Synnaeve and Juan Pino and Benoit Sagot and Emmanuel Dupoux},
77
+ year={2024},
78
+ eprint={2402.05755},
79
+ archivePrefix={arXiv},
80
+ primaryClass={cs.CL},
81
+ url={https://arxiv.org/abs/2402.05755},
82
+ }
83
+ ```
84
+
spiritlm.egg-info/SOURCES.txt ADDED
@@ -0,0 +1,32 @@
1
+ LICENSE
2
+ README.md
3
+ setup.py
4
+ spiritlm/__init__.py
5
+ spiritlm.egg-info/PKG-INFO
6
+ spiritlm.egg-info/SOURCES.txt
7
+ spiritlm.egg-info/dependency_links.txt
8
+ spiritlm.egg-info/not-zip-safe
9
+ spiritlm.egg-info/requires.txt
10
+ spiritlm.egg-info/top_level.txt
11
+ spiritlm/model/__init__.py
12
+ spiritlm/model/spiritlm_model.py
13
+ spiritlm/model/utils.py
14
+ spiritlm/speech_tokenizer/__init__.py
15
+ spiritlm/speech_tokenizer/spiritlm_tokenizer.py
16
+ spiritlm/speech_tokenizer/f0/__init__.py
17
+ spiritlm/speech_tokenizer/f0/f0_extractor.py
18
+ spiritlm/speech_tokenizer/f0/f0_tokenizer.py
19
+ spiritlm/speech_tokenizer/f0/vqvae.py
20
+ spiritlm/speech_tokenizer/hifigan/__init__.py
21
+ spiritlm/speech_tokenizer/hifigan/hifigan_vocoder.py
22
+ spiritlm/speech_tokenizer/hubert/__init__.py
23
+ spiritlm/speech_tokenizer/hubert/hubert_tokenizer.py
24
+ spiritlm/speech_tokenizer/hubert/quantizer_model.py
25
+ spiritlm/speech_tokenizer/hubert/hubert_model/__init__.py
26
+ spiritlm/speech_tokenizer/hubert/hubert_model/hubert_model.py
27
+ spiritlm/speech_tokenizer/hubert/hubert_model/wav2vec2_model.py
28
+ spiritlm/speech_tokenizer/style_encoder/__init__.py
29
+ spiritlm/speech_tokenizer/style_encoder/w2v2_encoder.py
30
+ tests/__init__.py
31
+ tests/test_spirit_model.py
32
+ tests/test_tokenizer.py
spiritlm.egg-info/dependency_links.txt ADDED
@@ -0,0 +1 @@
1
+
spiritlm.egg-info/not-zip-safe ADDED
@@ -0,0 +1 @@
1
+
spiritlm.egg-info/requires.txt ADDED
@@ -0,0 +1,15 @@
1
+ omegaconf>=2.2.0
2
+ librosa>=0.10
3
+ local-attention>=1.9
4
+ encodec>=0.1
5
+ transformers
6
+ fairscale>=0.4
7
+ sentencepiece
8
+ pyarrow>=14.0
9
+ torchfcpe>=0.0.4
10
+
11
+ [dev]
12
+ pytest
13
+
14
+ [eval]
15
+ pandas
spiritlm.egg-info/top_level.txt ADDED
@@ -0,0 +1,2 @@
1
+ spiritlm
2
+ tests
spiritlm/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the FAIR Noncommercial Research License
5
+ # found in the LICENSE file in the root directory of this source tree.
spiritlm/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (162 Bytes).
 
spiritlm/eval/README.md ADDED
@@ -0,0 +1,92 @@
1
+ # STSP Evaluation
2
+ The Speech-Text Sentiment Preservation (STSP) benchmark consists of a collection of speech and text prompts with a positive, negative, or neutral sentiment.
+ Given a spoken or written prompt, the task is to generate a text or speech sequence of tokens that preserves the sentiment of the prompt.
+
+ The sentiment of the generated sequence is evaluated automatically with a sentiment/emotion classifier in speech or text, depending on the output modality.
+ From these predictions, we derive an STSP accuracy score.
7
+
8
+ ## Data Download
9
+ Download the data as well as the speech/text classifier checkpoints via this [link](https://dl.fbaipublicfiles.com/textless_nlp/spiritlm/stsp.tar.gz),
+ then extract the data into the folder `{spiritlm ROOT FOLDER}/data/stsp_data`:
11
+ ```
12
+ cd {spiritlm ROOT FOLDER}
13
+ mkdir data/stsp_data
14
+ tar -xvzf stsp.tar.gz -C data/stsp_data --strip-components=1
15
+ ```
16
+ Run the following script to check that the dataset is correctly in place:
17
+ ```
18
+ python spiritlm/eval/stsp/sanity_check_download.py
19
+ ```
20
+ ## Data structure
21
+ The dataset contains 3 folders:
22
+ - `data`: raw audio files
23
+ - `manifest`: data splits
24
+ - `model`: speech/text classifier checkpoints
25
+ ### Data
26
+ The raw audio files for:
27
+ - `emov`: EMOV
28
+ - `expresso/conversational`: EXPRESSO-ASR
29
+ - `expresso/read`: EXPRESSO-READ
30
+
31
+ ### Manifest
32
+ The train/validation/test splits. Concretely, we have:
33
+
34
+ #### EMOV
35
+ - 1053 records for emov train split at `manifest/emov/emov.train.jsonl`
36
+ - 351 records for emov dev split at `manifest/emov/emov.dev.jsonl`
37
+ - 351 records for emov test split at `manifest/emov/emov.test.jsonl`
38
+
39
+ #### EXPRESSO-ASR
40
+ - 1373 records for EXPRESSO-ASR train split at `manifest/expresso/expresso_asr.train`
41
+ - 479 records for EXPRESSO-ASR dev at `manifest/expresso/expresso_asr.dev.jsonl`
42
+ - 462 records for EXPRESSO-ASR test split at `manifest/expresso/expresso_asr.test.jsonl`
43
+
44
+ #### EXPRESSO-READ
45
+ - 1024 records for EXPRESSO-READ train split at `manifest/expresso/expresso_read.train`
46
+ - 60 records for EXPRESSO-READ dev at `manifest/expresso/expresso_read.dev.jsonl`
47
+ - 54 records for EXPRESSO-READ test split at `manifest/expresso/expresso_read.test.jsonl`
48
+
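+ For reference, each manifest line is a JSON record; the evaluation code reads the `id`, `wav_path` (speech input), `asr` (text input) and `sentiment`/`emotion` fields. An illustrative (not verbatim) record looks like:
+ ```
+ {"id": "0001", "wav_path": "emov/xxx.wav", "asr": "the transcribed prompt text", "emotion": "happy"}
+ ```
+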
49
+ #### Few-shot Samples
50
+ The subset of the EXPRESSO-ASR training set used for the few-shot experiments:
51
+ - `s2s.jsonl`: S -> S direction
52
+ - `s2t.jsonl`: S -> T direction
53
+ - `t2t.jsonl`: T -> T direction
54
+ - `t2s.jsonl`: T -> S direction
55
+
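+ Each few-shot manifest line contains a `prompt`, a `generation` and a `sentiment` field (for the speech directions, `prompt` and `generation` point to wav files under the data folder). An illustrative (not verbatim) record for the T -> T direction looks like:
+ ```
+ {"prompt": "the prompt text", "generation": "a continuation with the same sentiment", "sentiment": "positive"}
+ ```
+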
56
+ ### Auto-Eval Speech And Text Classifiers
57
+
58
+ The sentiment of the generated sequence is estimated in an auto-eval fashion with speech and text classifiers. We refer to the [paper](https://arxiv.org/abs/2402.05755) for details on these classifiers.
59
+
60
+
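+ For reference, both auto-eval classifiers can be loaded directly with the helpers shipped in this repository. The sketch below assumes the archive was extracted to the default `data/stsp_data` layout and that a GPU is available (the speech classifier is loaded on GPU):
+ ```python
+ from spiritlm.eval.stsp.sentiment_classifiers import (
+     get_text_sentiment_prediction,
+     load_sentiment_classifier,
+ )
+ from spiritlm.eval.stsp.utils import load_emotion_classifier, wav2emotion_and_sentiment
+
+ # Text classifier: returns a (score, label) pair, e.g. (0.98, "positive")
+ text_classifier = load_sentiment_classifier("data/stsp_data/model/text_classifier")
+ score, sentiment = get_text_sentiment_prediction("I am so happy to see you!", text_classifier)
+
+ # Speech classifier: maps a wav (path or array) to an (emotion, sentiment) pair
+ speech_classifier = load_emotion_classifier("data/stsp_data/model/speech_classifier")
+ emotion, sentiment = wav2emotion_and_sentiment("examples/audio/7143-88743-0029.flac", speech_classifier)
+ ```
+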
61
+ ## Prediction & Evaluation of Spirit LM on STSP (Speech/Text)
62
+
63
+ First, set the Python path from the repository root: `export PYTHONPATH=.`
64
+
65
+ Set `spiritlm` to the model you want to evaluate, e.g. `spirit-lm-base-7b` or `spirit-lm-expressive-7b`:
66
+
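+ For example:
+ ```
+ export spiritlm=spirit-lm-base-7b  # or: export spiritlm=spirit-lm-expressive-7b
+ ```
+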
67
+ #### Speech to Text
+ ```
+ torchrun --nnodes 1 --nproc-per-node 1 spiritlm/eval/stsp/predict_stsp.py --model $spiritlm --eval_manifest_path data/stsp_data/manifest/emov/emov.test.jsonl --eval --write_pred ./pred_s_t.jsonl --input_output speech_text
+ ```
+ #### Text to Text
+ ```
+ torchrun --nnodes 1 --nproc-per-node 1 spiritlm/eval/stsp/predict_stsp.py --model $spiritlm --eval_manifest_path data/stsp_data/manifest/emov/emov.test.jsonl --eval --write_pred ./pred_t_t.jsonl --input_output text_text
+ ```
+ #### Text to Speech
+ ```
+ torchrun --nnodes 1 --nproc-per-node 1 spiritlm/eval/stsp/predict_stsp.py --model $spiritlm --eval_manifest_path data/stsp_data/manifest/emov/emov.test.jsonl --eval --write_pred ./pred_t_s.jsonl --input_output text_speech
+ ```
+ #### Speech to Speech
+ ```
+ torchrun --nnodes 1 --nproc-per-node 1 spiritlm/eval/stsp/predict_stsp.py --model $spiritlm --eval_manifest_path data/stsp_data/manifest/emov/emov.test.jsonl --eval --write_pred ./pred_s_s.jsonl --input_output speech_speech
+ ```
75
+
76
+
77
+ ### Post-hoc Evaluation
78
+
79
+ To evaluate the performance of a model other than Spirit LM, you can use the following evaluation script, which takes a prediction jsonl file as input.
80
+
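+ The prediction file must contain one JSON record per evaluated utterance, with an `id` matching the reference and the predicted label under `pred`; an illustrative (not verbatim) line looks like:
+ ```
+ {"id": "0001", "pred": "positive"}
+ ```
+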
81
+ ```
82
+ python spiritlm/eval/eval_stsp.py --ref_file $REF_FILE --pred_file $pred_file
83
+ ```
84
+
85
+ e.g.
86
+
87
+ ```
88
+ python spiritlm/eval/eval_stsp.py \
89
+ --ref_file ./data/examples/demo.jsonl \
90
+ --pred_file ./data/examples/pred.jsonl
91
+ > Accuracy: 100.00% for predictions ./data/examples/pred.jsonl
92
+ ```
spiritlm/eval/eval_stsp.py ADDED
@@ -0,0 +1,87 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the FAIR Noncommercial Research License
5
+ # found in the LICENSE file in the root directory of this source tree.
6
+
7
+ import argparse
8
+ import json
9
+ from typing import Dict, Union
10
+
11
+ import pandas as pd
12
+ from spiritlm.eval.stsp.utils import EMOTION_2_SENTIMENT
13
+
14
+
15
+ def load_pred(predictions):
16
+ ret = {}
17
+ with open(predictions) as f:
18
+ for line in f:
19
+ pred = json.loads(line)
20
+ ret[str(pred["id"])] = pred["pred"]
21
+
22
+ assert sum(1 for _ in open(predictions)) == len(ret)
23
+
24
+ return ret
25
+
26
+
27
+ def eval(
28
+ gold_records: str, predictions: Union[str, Dict], info_data="", label="sentiment"
29
+ ):
30
+ n_gold_records = sum(1 for _ in open(gold_records))
31
+ n_lines_pred = (
32
+ sum(1 for _ in open(predictions))
33
+ if isinstance(predictions, str)
34
+ else len(predictions)
35
+ )
36
+ assert (
37
+ n_gold_records == n_lines_pred
38
+ ), f"Mismatch between prediction ({n_lines_pred} samples in {predictions}) and reference ({n_gold_records} in {gold_records})"
39
+
40
+ pred_dic = load_pred(predictions) if isinstance(predictions, str) else predictions
41
+ scores = []
42
+
43
+ with open(gold_records) as gold:
44
+ for line in gold:
45
+ ref = json.loads(line)
46
+ try:
47
+ if label in ref:
48
+ scores.append(pred_dic[str(ref["id"])] == ref[label])
49
+ else:
50
+ assert label == "sentiment" and "emotion" in ref, ref
51
+ sentiment = EMOTION_2_SENTIMENT[ref["emotion"]]
52
+ scores.append(pred_dic[str(ref["id"])] == sentiment)
53
+ except Exception as e:
54
+ print(
55
+ f"ERROR in matching the predicted labels with the gold ones: {e}: ref['id'] do not match any key in {pred_dic}', {ref['id']}: "
56
+ )
57
+ # TODO: add other metrics if needed : F1 per class, etc.
58
+ report = pd.DataFrame({"Correct": scores})
59
+ if isinstance(predictions, str):
60
+ info_data += f"from {predictions}"
61
+ print(
62
+ f"Accuracy: {(report['Correct']==1).sum()/len(report)*100:0.2f}% for predictions {info_data}"
63
+ )
64
+
65
+
66
+ if __name__ == "__main__":
67
+ parser = argparse.ArgumentParser()
68
+
69
+ parser.add_argument(
70
+ "--ref_file",
71
+ type=str,
72
+ help="Path to reference record",
73
+ )
74
+ parser.add_argument(
75
+ "--pred_file",
76
+ type=str,
77
+ help="Path to prediction: should be jsonl with each entry {'pred': , 'id': }",
78
+ )
79
+ parser.add_argument(
80
+ "--label",
81
+ type=str,
82
+ default="sentiment",
83
+ help="sentiment or emotion",
84
+ )
85
+ args = parser.parse_args()
86
+
87
+ eval(args.ref_file, args.pred_file, label=args.label)
spiritlm/eval/load_data.py ADDED
@@ -0,0 +1,50 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the FAIR Noncommercial Research License
5
+ # found in the LICENSE file in the root directory of this source tree.
6
+
7
+ import json
8
+ from pathlib import Path
9
+
10
+ import torch
11
+ import torchaudio
12
+
13
+
14
+ class SpeechData(torch.utils.data.Dataset):
15
+ def __init__(self, manifest_dir, root_dir=None):
16
+ if root_dir is None:
17
+ root_dir = "."
18
+ self.root_dir = Path(root_dir)
19
+ self.manifest_dir = self.root_dir / manifest_dir
20
+ self.wav_field = "wav_path"
21
+ self.manifest = [json.loads(line.strip()) for line in open(manifest_dir)]
22
+
23
+ def __getitem__(self, idx):
24
+ wav_path = self.root_dir / self.manifest[idx][self.wav_field]
25
+ return {
26
+ "wav": torchaudio.load(wav_path)[0].squeeze(0),
27
+ "id": str(self.manifest[idx]["id"]),
28
+ }
29
+
30
+ def __len__(self):
31
+ return len(self.manifest)
32
+
33
+
34
+ class TextData(torch.utils.data.Dataset):
35
+ def __init__(self, manifest_dir, root_dir=None):
36
+ if root_dir is None:
37
+ root_dir = "."
38
+ self.root_dir = Path(root_dir)
39
+ self.manifest_dir = self.root_dir / manifest_dir
40
+ self.text_field = "asr"
41
+ self.manifest = [json.loads(line.strip()) for line in open(manifest_dir)]
42
+
43
+ def __getitem__(self, idx):
44
+ return {
45
+ "text": self.manifest[idx][self.text_field],
46
+ "id": str(self.manifest[idx]["id"]),
47
+ }
48
+
49
+ def __len__(self):
50
+ return len(self.manifest)
spiritlm/eval/stsp/few_shot_prompt.py ADDED
@@ -0,0 +1,101 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the FAIR Noncommercial Research License
5
+ # found in the LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+ from typing import Union
9
+
10
+ import pandas as pd
11
+ import torch
12
+ import torchaudio
13
+ from spiritlm.eval.stsp.stsp_constants import STSP_DATA_ROOT, STSP_MANIFEST_ROOT
14
+ from spiritlm.model.spiritlm_model import Spiritlm
15
+
16
+ FEW_SHOT_MANIFEST_DIR = STSP_MANIFEST_ROOT / "few_shot"
17
+ FEW_SHOT_TEMPLATE = "{prompt}{generation}"
18
+
19
+
20
+ def wav_prompt(spiritlm_model: Spiritlm, wav: Union[str, torch.Tensor]) -> str:
21
+ return spiritlm_model.SPEECH_PROMPT_PREFIX + spiritlm_model.speech_tokenizer(wav)
22
+
23
+
24
+ def text_prompt(spiritlm_model: Spiritlm, text: str) -> str:
25
+ return spiritlm_model.TEXT_PROMPT_PREFIX + text
26
+
27
+
28
+ def _load_half_wav(wav_path: str, load_first_half: bool) -> torch.Tensor:
29
+ wav_path = STSP_DATA_ROOT / wav_path
30
+ wav = torchaudio.load(wav_path)[0].squeeze(0)
31
+ size = wav.size()[0]
32
+ half_size = size // 2
33
+ if load_first_half:
34
+ wav = wav[:half_size]
35
+ else:
36
+ wav = wav[half_size:]
37
+ return wav
38
+
39
+
40
+ def build_few_shot_prompt(
41
+ spiritlm_model: Spiritlm,
42
+ input_output: str,
43
+ n_shots: int = 3,
44
+ ) -> str:
45
+ """
46
+ Build the few-shot prompt by simply concatenating a set of examples.
47
+
48
+ E.g., a 3-shots T->S prompt would like this:
49
+ "[Text]text1[Speech]speech_tokens1\n[Text]text2[Speech]speech_tokens2\n[Text]text3[Speech]speech_tokens3\n"
50
+ """
51
+ manifset_file_mapping = {
52
+ "text_text": "t2t",
53
+ "speech_text": "s2t",
54
+ "text_speech": "t2s",
55
+ "speech_speech": "s2s",
56
+ }
57
+ manifest_path = (
58
+ FEW_SHOT_MANIFEST_DIR / f"{manifset_file_mapping[input_output]}.jsonl"
59
+ )
60
+ df = pd.read_json(manifest_path, lines=True)
61
+ assert n_shots <= len(df)
62
+
63
+ # ensure a balanced sampels for each sentiment
64
+ nb_samples_per_sentiment = math.ceil(n_shots / 3)
65
+ df = df.groupby("sentiment").sample(n=nb_samples_per_sentiment)
66
+
67
+ prompts = []
68
+ for _, row in df.iterrows():
69
+ prompt = row["prompt"]
70
+ generation = row["generation"]
71
+ if input_output == "text_text":
72
+ prompt = FEW_SHOT_TEMPLATE.format(
73
+ prompt=text_prompt(spiritlm_model, prompt),
74
+ generation=text_prompt(spiritlm_model, generation),
75
+ )
76
+ elif input_output == "text_speech":
77
+ prompt = FEW_SHOT_TEMPLATE.format(
78
+ prompt=text_prompt(spiritlm_model, prompt),
79
+ generation=wav_prompt(
80
+ spiritlm_model, _load_half_wav(generation, load_first_half=False)
81
+ ),
82
+ )
83
+ elif input_output == "speech_text":
84
+ prompt = FEW_SHOT_TEMPLATE.format(
85
+ prompt=wav_prompt(
86
+ spiritlm_model, _load_half_wav(prompt, load_first_half=True)
87
+ ),
88
+ generation=text_prompt(spiritlm_model, generation),
89
+ )
90
+ elif input_output == "speech_speech":
91
+ prompt = FEW_SHOT_TEMPLATE.format(
92
+ prompt=wav_prompt(
93
+ spiritlm_model, _load_half_wav(prompt, load_first_half=True)
94
+ ),
95
+ generation=wav_prompt(
96
+ spiritlm_model, _load_half_wav(generation, load_first_half=False)
97
+ ),
98
+ )
99
+ prompts.append(prompt)
100
+ print(f"prompts: {prompts}")
101
+ return "\n".join(prompts) + "\n"
spiritlm/eval/stsp/predict_stsp.py ADDED
@@ -0,0 +1,299 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the FAIR Noncommercial Research License
5
+ # found in the LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Usage example:
9
+
10
+ cd {SPIRITLM ROOT FOLDER}
11
+ export PYTHONPATH=.
12
+
13
+ # Speech to Text
14
+ torchrun --nnodes 1 --nproc-per-node 1 spiritlm/eval/stsp/predict_stsp.py --eval_manifest_path data/examples/ref.jsonl --eval --write_pred ./pred_s_t.jsonl --input_output speech_text
15
+ # Text to Text
16
+ torchrun --nnodes 1 --nproc-per-node 1 spiritlm/eval/stsp/predict_stsp.py --eval_manifest_path data/examples/ref.jsonl --eval --write_pred ./pred_t_t.jsonl --input_output text_text
17
+ # Text to Speech#
18
+ torchrun --nnodes 1 --nproc-per-node 1 spiritlm/eval/stsp/predict_stsp.py --eval_manifest_path data/examples/ref.jsonl --eval --write_pred ./pred._t_s.jsonl --input_output text_speech
19
+ # Speech to Speech
20
+ torchrun --nnodes 1 --nproc-per-node 1 spiritlm/eval/stsp/predict_stsp.py --eval_manifest_path data/examples/ref.jsonl --eval --write_pred ./pred_s_s.jsonl --input_output speech_speech
21
+
22
+ """
23
+
24
+ import argparse
25
+ import json
26
+ import os
27
+ import uuid
28
+ from pathlib import Path
29
+ from typing import Union
30
+
31
+ import torch
32
+ import torch.distributed as dist
33
+ import torchaudio
34
+ from spiritlm.eval.eval_stsp import eval
35
+ from spiritlm.eval.load_data import SpeechData, TextData
36
+ from spiritlm.eval.stsp.few_shot_prompt import build_few_shot_prompt
37
+ from spiritlm.eval.stsp.sentiment_classifiers import (
38
+ get_text_sentiment_prediction,
39
+ load_sentiment_classifier,
40
+ )
41
+ from spiritlm.eval.stsp.stsp_constants import STSP_DATA_ROOT, STSP_MODEL_ROOT
42
+ from spiritlm.eval.stsp.utils import (
43
+ ExpressoEmotionClassifier,
44
+ load_emotion_classifier,
45
+ wav2emotion_and_sentiment,
46
+ )
47
+ from spiritlm.model.spiritlm_model import (
48
+ ContentType,
49
+ GenerationInput,
50
+ InterleavedOutputs,
51
+ OutputModality,
52
+ Spiritlm,
53
+ )
54
+ from torch.utils.data.distributed import DistributedSampler
55
+ from tqdm import tqdm
56
+ from transformers import AutoModelForSequenceClassification, GenerationConfig, set_seed
57
+
58
+ SPEECH_CLASSIFIER = STSP_MODEL_ROOT / "speech_classifier"
59
+ TEXT_CLASSIFIER = STSP_MODEL_ROOT / "text_classifier"
60
+
61
+ NB_RETRIES = 3
62
+
63
+
64
+ def get_eval_classifier(args):
65
+ if args.input_output.endswith("speech"):
66
+ return load_emotion_classifier(str(SPEECH_CLASSIFIER))
67
+ elif args.input_output.endswith("text"):
68
+ return load_sentiment_classifier(str(TEXT_CLASSIFIER))
69
+ else:
70
+ raise (Exception(f"{args.input_output} not supported"))
71
+
72
+
73
+ def get_sentiment(
74
+ input_output,
75
+ generation,
76
+ classifier: Union[AutoModelForSequenceClassification, ExpressoEmotionClassifier],
77
+ ):
78
+ if input_output.endswith("speech"):
79
+ _, pred_sentiment = wav2emotion_and_sentiment(generation, classifier)
80
+ elif input_output.endswith("text"):
81
+ _, pred_sentiment = get_text_sentiment_prediction(generation, classifier)
82
+ return pred_sentiment
83
+
84
+
85
+ def write_jsonl(dir: str, predictions: dict):
86
+ Path(dir).parent.mkdir(exist_ok=True, parents=True)
87
+ with open(dir, "w") as f:
88
+ for id, result_dict in predictions.items():
89
+ record = {"id": id, **result_dict}
90
+ json_string = json.dumps(record)
91
+ f.write(json_string + "\n") # Add a newline to separate JSON objects
92
+ print(f"{dir} written")
93
+
94
+
95
+ def write_wav(
96
+ wav,
97
+ save_dir: Path,
98
+ sample_rate: int = 16_000,
99
+ ) -> str:
100
+ """Save wav under `save_dir` with a random name and return the full path."""
101
+ save_dir.mkdir(exist_ok=True, parents=True)
102
+ random_path = save_dir / (str(uuid.uuid4()) + ".wav")
103
+ torchaudio.save(
104
+ random_path, torch.from_numpy(wav).unsqueeze(0), sample_rate=sample_rate
105
+ )
106
+ return str(random_path)
107
+
108
+
109
+ def run(args):
110
+ world_size = int(os.environ["WORLD_SIZE"])
111
+ world_rank = int(os.environ["RANK"])
112
+ print(
113
+ f"Running distributed inference with world_size: {world_size}, world_rank: {world_rank}"
114
+ )
115
+ dist.init_process_group("nccl", rank=world_rank, world_size=world_size)
116
+ set_seed(args.seed)
117
+ spiritlm_model = Spiritlm(args.model)
118
+ evaluation_classifier = get_eval_classifier(args)
119
+ input_output = args.input_output
120
+ eval_manifest_path = args.eval_manifest_path
121
+ write_wav_output = args.write_wav_output
122
+
123
+ if args.few_shot > 0:
124
+ prompt = build_few_shot_prompt(
125
+ spiritlm_model=spiritlm_model,
126
+ input_output=args.input_output,
127
+ n_shots=args.few_shot,
128
+ )
129
+ else:
130
+ prompt = None
131
+
132
+ # load
133
+ if input_output.startswith("speech"):
134
+ eval_dataset = SpeechData(eval_manifest_path, root_dir=STSP_DATA_ROOT)
135
+ elif input_output.startswith("text"):
136
+ eval_dataset = TextData(eval_manifest_path, root_dir=STSP_DATA_ROOT)
137
+
138
+ sampler = DistributedSampler(dataset=eval_dataset)
139
+ loader = torch.utils.data.DataLoader(
140
+ dataset=eval_dataset,
141
+ batch_size=1, # large batch size is not supported yet
142
+ sampler=sampler,
143
+ num_workers=4,
144
+ )
145
+ predictions = {}
146
+ if input_output.endswith("speech"):
147
+ output_modality = OutputModality.SPEECH
148
+ max_new_tokens = 300
149
+ else:
150
+ output_modality = OutputModality.TEXT
151
+ max_new_tokens = 50
152
+ for _, data in tqdm(
153
+ enumerate(loader),
154
+ desc=f"Predict {eval_manifest_path}",
155
+ total=eval_dataset.__len__() // world_size,
156
+ ):
157
+ # retry the generation multiple times because sometime it does not generate hubert tokens
158
+ for i in range(NB_RETRIES):
159
+ try:
160
+ out: InterleavedOutputs = spiritlm_model.generate(
161
+ output_modality=output_modality,
162
+ interleaved_inputs=[
163
+ GenerationInput(
164
+ content=(
165
+ data["wav"][0]
166
+ if input_output.startswith("speech")
167
+ else data["text"][0]
168
+ ), # 0 because of batch size 1
169
+ content_type=(
170
+ ContentType.SPEECH
171
+ if input_output.startswith("speech")
172
+ else ContentType.TEXT
173
+ ),
174
+ )
175
+ ],
176
+ generation_config=GenerationConfig(
177
+ temperature=0.8,
178
+ top_p=0.95,
179
+ max_new_tokens=max_new_tokens,
180
+ do_sample=True,
181
+ ),
182
+ prompt=prompt,
183
+ )
184
+ except Exception as e:
185
+ print(f"Got an exception when generating: {e}")
186
+ if i == NB_RETRIES - 1:
187
+ raise Exception(f"Failed to generate after {NB_RETRIES}")
188
+ else:
189
+ break
190
+ assert len(out) == 1
191
+ generated_output = out[0].content
192
+ detected_sentiment = get_sentiment(
193
+ input_output, generated_output, evaluation_classifier
194
+ )
195
+ if output_modality == OutputModality.TEXT:
196
+ generation = generated_output
197
+ elif write_wav_output and output_modality == OutputModality.SPEECH:
198
+ generation = write_wav(generated_output, Path(write_wav_output))
199
+ else:
200
+ generation = None
201
+ result_dict = {"pred": detected_sentiment}
202
+ if generation is not None:
203
+ result_dict["generation"] = generation
204
+ predictions[str(data["id"][0])] = result_dict
205
+
206
+ if args.eval:
207
+ gathered_predictions = [None for _ in range(world_size)]
208
+ dist.gather_object(
209
+ predictions, gathered_predictions if world_rank == 0 else None, dst=0
210
+ )
211
+ if world_rank == 0:
212
+ all_predictions = {k: v for d in gathered_predictions for k, v in d.items()}
213
+ eval(
214
+ eval_manifest_path,
215
+ {k: v["pred"] for k, v in all_predictions.items()},
216
+ info_data=f"{eval_manifest_path}, input-output {input_output}",
217
+ label="sentiment",
218
+ )
219
+
220
+ if args.write_pred is not None and world_rank == 0:
221
+ write_jsonl(args.write_pred, all_predictions)
222
+
223
+
224
+ def setup_env():
225
+ os.environ["OMP_NUM_THREADS"] = "1"
226
+
227
+
228
+ if __name__ == "__main__":
229
+ parser = argparse.ArgumentParser()
230
+ parser.add_argument(
231
+ "--eval_manifest_path", # data/examples/ref.jsonl
232
+ type=str,
233
+ help="Path to reference record",
234
+ required=True,
235
+ )
236
+
237
+ parser.add_argument(
238
+ "--data_root_dir", # data/stsp_data
239
+ type=str,
240
+ help=f"Path to root data folder, default to {str(STSP_DATA_ROOT)}",
241
+ default=str(STSP_DATA_ROOT),
242
+ required=False,
243
+ )
244
+
245
+ parser.add_argument(
246
+ "--model",
247
+ type=str,
248
+ default="spirit-lm-expressive-7b",
249
+ help="Model name (spirit-lm-base-7b or spirit-lm-expressive-7b) or path to model",
250
+ required=False,
251
+ )
252
+ parser.add_argument(
253
+ "--few_shot",
254
+ type=int,
255
+ default=0,
256
+ help="Number of few shot examples, 3/6/9",
257
+ required=False,
258
+ )
259
+ parser.add_argument(
260
+ "--input_output",
261
+ type=str,
262
+ default="speech_speech",
263
+ help="speech_speech speech_text text_speech text_text",
264
+ required=False,
265
+ )
266
+ parser.add_argument(
267
+ "--eval_type",
268
+ type=str,
269
+ default="emotion",
270
+ required=False,
271
+ )
272
+ parser.add_argument(
273
+ "--write_pred",
274
+ type=str,
275
+ default=None,
276
+ help="Path to save the predictions output",
277
+ required=False,
278
+ )
279
+ parser.add_argument(
280
+ "--write_wav_output",
281
+ type=str,
282
+ default=None,
283
+ help="Path to save the generated audio if the output is speech",
284
+ required=False,
285
+ )
286
+ parser.add_argument(
287
+ "--eval",
288
+ default=False,
289
+ action="store_true",
290
+ )
291
+ parser.add_argument(
292
+ "--seed",
293
+ default=0,
294
+ type=int,
295
+ )
296
+
297
+ args = parser.parse_args()
298
+ setup_env()
299
+ run(args)
spiritlm/eval/stsp/sanity_check_download.py ADDED
@@ -0,0 +1,30 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the FAIR Noncommercial Research License
5
+ # found in the LICENSE file in the root directory of this source tree.
6
+
7
+ import json
8
+
9
+ from spiritlm.eval.stsp.stsp_constants import STSP_DATA_ROOT, STSP_MANIFEST_ROOT
10
+
11
+
12
+ def check_all_datasets():
13
+ for dataset_manifest in STSP_MANIFEST_ROOT.glob("**/*jsonl"):
14
+ records_checked = 0
15
+ print(f"dataset_manifest: {dataset_manifest}")
16
+ with dataset_manifest.open() as f:
17
+ for record in f:
18
+ record = json.loads(record)
19
+ for wav_key in ["wav_path", "prompt", "generation"]:
20
+ if wav_key in record and record[wav_key].endswith(".wav"):
21
+ wav_path = STSP_DATA_ROOT / record[wav_key]
22
+ assert (
23
+ wav_path.is_file()
24
+ ), f"Record {record[wav_key]} not found in {str(wav_path)} and listed in {dataset_manifest}"
25
+ records_checked += 1
26
+ print(f"{records_checked} records checked for {dataset_manifest.stem} split")
27
+
28
+
29
+ if __name__ == "__main__":
30
+ check_all_datasets()
spiritlm/eval/stsp/sentiment_classifiers.py ADDED
@@ -0,0 +1,37 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the FAIR Noncommercial Research License
5
+ # found in the LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import Any, Dict, List, Tuple
8
+
9
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
10
+
11
+
12
+ def pred_to_label(
13
+ sentiment_prediction_scores: List[List[Dict[str, Any]]],
14
+ ) -> Tuple[float, str]:
15
+ if isinstance(sentiment_prediction_scores[0], list):
16
+ sentiment_prediction_scores = sentiment_prediction_scores[0]
17
+ item_with_max_score = max(
18
+ sentiment_prediction_scores, key=lambda _dict: _dict["score"]
19
+ )
20
+ score = item_with_max_score["score"]
21
+ return score, item_with_max_score["label"].lower()
22
+
23
+
24
+ def get_text_sentiment_prediction(text: str, sentiment_classifier) -> Tuple[float, str]:
25
+ return pred_to_label(sentiment_classifier(text))
26
+
27
+
28
+ def load_sentiment_classifier(model_dir: str):
29
+ classifier = pipeline(
30
+ task="text-classification",
31
+ model=AutoModelForSequenceClassification.from_pretrained(model_dir),
32
+ tokenizer=AutoTokenizer.from_pretrained(
33
+ "j-hartmann/sentiment-roberta-large-english-3-classes"
34
+ ),
35
+ top_k=None,
36
+ )
37
+ return classifier
spiritlm/eval/stsp/stsp_constants.py ADDED
@@ -0,0 +1,12 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the FAIR Noncommercial Research License
5
+ # found in the LICENSE file in the root directory of this source tree.
6
+
7
+ from pathlib import Path
8
+
9
+ STSP_ROOT = Path(__file__).parents[3] / "data" / "stsp_data"
10
+ STSP_DATA_ROOT = STSP_ROOT / "data"
11
+ STSP_MODEL_ROOT = STSP_ROOT / "model"
12
+ STSP_MANIFEST_ROOT = STSP_ROOT / "manifest"
spiritlm/eval/stsp/utils.py ADDED
@@ -0,0 +1,122 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the FAIR Noncommercial Research License
5
+ # found in the LICENSE file in the root directory of this source tree.
6
+
7
+ from dataclasses import dataclass
8
+ from functools import cache
9
+ from typing import List, Optional, Tuple
10
+
11
+ import torch
12
+ import torchaudio
13
+ from transformers import AutoFeatureExtractor, AutoModelForAudioClassification
14
+
15
+ EXPRESSO_EMOTION_2_SENTIMENT = {
16
+ "happy": "positive",
17
+ "angry": "negative",
18
+ "sad": "negative",
19
+ "default": "neutral",
20
+ }
21
+
22
+ EMOTION_2_SENTIMENT = {
23
+ "happy": "positive",
24
+ "angry": "negative",
25
+ "sad": "negative",
26
+ "default": "neutral",
27
+ "neutral": "neutral",
28
+ "amused": "positive",
29
+ }
30
+
31
+
32
+ @cache
33
+ def emotions2new_label_names_and_indices(
34
+ emotions_to_select: Tuple[str],
35
+ label_names: Tuple[str],
36
+ ) -> Tuple[List[int], List[str]]:
37
+ emotion2index = {e: i for i, e in enumerate(label_names)}
38
+ sorted_indices_emotions = sorted(
39
+ [(emotion2index[emotion], emotion) for emotion in emotions_to_select]
40
+ )
41
+ zipped = list(zip(*sorted_indices_emotions))
42
+ return zipped
43
+
44
+
45
+ def expresso_emotion2_sentiment(emotion: str):
46
+ return EXPRESSO_EMOTION_2_SENTIMENT[emotion]
47
+
48
+
49
+ @dataclass
50
+ class ExpressoEmotionClassifier:
51
+ feature_extractor: AutoFeatureExtractor
52
+ model: AutoModelForAudioClassification
53
+ label_names: List[str]
54
+
55
+
56
+ def load_emotion_classifier(checkpoint_path: str) -> ExpressoEmotionClassifier:
57
+ feature_extractor = AutoFeatureExtractor.from_pretrained(checkpoint_path)
58
+ model = (
59
+ AutoModelForAudioClassification.from_pretrained(checkpoint_path).cuda().eval()
60
+ )
61
+ label_names = [model.config.id2label[i] for i in range(model.config.num_labels)]
62
+ print(f"Classification model loaded from {checkpoint_path} !")
63
+ return ExpressoEmotionClassifier(feature_extractor, model, label_names)
64
+
65
+
66
+ @torch.inference_mode()
67
+ def predict_audio(
68
+ audio,
69
+ expresso_emotion_classifier: ExpressoEmotionClassifier,
70
+ emotions_to_predict: Optional[List[str]] = None,
71
+ ):
72
+ if isinstance(audio, str):
73
+ speech, _ = torchaudio.load(audio)
74
+ resampler = torchaudio.transforms.Resample(
75
+ expresso_emotion_classifier.feature_extractor.sampling_rate
76
+ )
77
+ speech = resampler(speech).squeeze().numpy()
78
+ else:
79
+ speech = audio
80
+
81
+ features = expresso_emotion_classifier.feature_extractor(
82
+ speech,
83
+ sampling_rate=expresso_emotion_classifier.feature_extractor.sampling_rate,
84
+ return_tensors="pt",
85
+ )
86
+ features["input_values"] = features["input_values"].cuda()
87
+
88
+ logits = expresso_emotion_classifier.model(**features).logits
89
+ if emotions_to_predict is not None:
90
+ (indices, label_names) = emotions2new_label_names_and_indices(
91
+ tuple(emotions_to_predict), tuple(expresso_emotion_classifier.label_names)
92
+ )
93
+ logits = logits[:, indices]
94
+ else:
95
+ label_names = expresso_emotion_classifier.label_names
96
+ pred_id = torch.argmax(logits, dim=-1)[0].item()
97
+
98
+ return label_names[pred_id], logits.detach().cpu().numpy()
99
+
100
+
101
+ def wav2emotion(
102
+ wav,
103
+ expresso_emotion_classifier: ExpressoEmotionClassifier,
104
+ emotions_to_predict: Optional[List[str]] = None,
105
+ ) -> str:
106
+ label_logits = predict_audio(
107
+ audio=wav,
108
+ expresso_emotion_classifier=expresso_emotion_classifier,
109
+ emotions_to_predict=emotions_to_predict,
110
+ )
111
+ pred_emotion = label_logits[0]
112
+ return pred_emotion
113
+
114
+
115
+ def wav2emotion_and_sentiment(
116
+ wav,
117
+ expresso_emotion_classifier: ExpressoEmotionClassifier,
118
+ emotions_to_predict: Optional[List[str]] = None,
119
+ ) -> Tuple[str, str]:
120
+ pred_emotion = wav2emotion(wav, expresso_emotion_classifier, emotions_to_predict)
121
+ mapped_sentiment = expresso_emotion2_sentiment(pred_emotion)
122
+ return pred_emotion, mapped_sentiment
spiritlm/eval/utils.py ADDED
@@ -0,0 +1,17 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the FAIR Noncommercial Research License
5
+ # found in the LICENSE file in the root directory of this source tree.
6
+
7
+ import torchaudio
8
+ from spiritlm.model.spiritlm_model import Spiritlm
9
+
10
+
11
+ def wav_prompt(spiritlm_model: Spiritlm, wav_path: str) -> str:
12
+ wav = torchaudio.load(wav_path)[0].squeeze(0)
13
+ return spiritlm_model.SPEECH_PROMPT_PREFIX + spiritlm_model.speech_tokenizer(wav)
14
+
15
+
16
+ def text_prompt(spiritlm_model: Spiritlm, text: str) -> str:
17
+ return spiritlm_model.TEXT_PROMPT_PREFIX + text
spiritlm/model/README.md ADDED
@@ -0,0 +1,82 @@
1
+ # Model for Spirit LM
2
+ This directory contains the Spirit LM model wrapper.
3
+
4
+ ## Usage examples
5
+
6
+ ### Model Loading
7
+ ```python
8
+ from spiritlm.model.spiritlm_model import Spiritlm
9
+
10
+ # Spirit LM Base 7B
11
+ spirit_lm = Spiritlm("spirit-lm-base-7b")
12
+
13
+ # Spirit LM Expressive 7B
14
+ spirit_lm = Spiritlm("spirit-lm-expressive-7b")
15
+ ```
16
+
17
+ ### Generation examples
18
+ ```python
19
+ from spiritlm.model.spiritlm_model import OutputModality, GenerationInput, ContentType
20
+ from transformers import GenerationConfig
21
+
22
+ # Generate only text
23
+ spirit_lm.generate(
24
+ output_modality=OutputModality.TEXT,
25
+ interleaved_inputs=[
26
+ GenerationInput(
27
+ content="The largest country in the world is",
28
+ content_type=ContentType.TEXT,
29
+ )
30
+ ],
31
+ generation_config=GenerationConfig(
32
+ temperature=0.9,
33
+ top_p=0.95,
34
+ max_new_tokens=50,
35
+ do_sample=True,
36
+ ),
37
+ )
38
+
39
+ # Expected output format:
40
+ # [GenerationOuput(content='Russia, with an area of ...', content_type=<ContentType.TEXT: 'TEXT'>)]
41
+
42
+ # Generate only speech
43
+ spirit_lm.generate(
44
+ output_modality=OutputModality.SPEECH,
45
+ interleaved_inputs=[
46
+ GenerationInput(
47
+ content="examples/audio/7143-88743-0029.flac",
48
+ content_type=ContentType.SPEECH,
49
+ )
50
+ ],
51
+ generation_config=GenerationConfig(
52
+ temperature=0.9,
53
+ top_p=0.95,
54
+ max_new_tokens=200,
55
+ do_sample=True,
56
+ ),
57
+ )
58
+
59
+ # Expected output format:
60
+ # [GenerationOuput(content=array([ 3.6673620e-05, 2.6468514e-04, 1.0735081e-03, ...,], dtype=float32), content_type=<ContentType.SPEECH: 'SPEECH'>)]
61
+
62
+
63
+ # Arbitrary generation
64
+ spirit_lm.generate(
65
+ output_modality=OutputModality.ARBITRARY,
66
+ interleaved_inputs=[
67
+ GenerationInput(
68
+ content="examples/audio/7143-88743-0029.flac",
69
+ content_type=ContentType.SPEECH,
70
+ )
71
+ ],
72
+ generation_config=GenerationConfig(
73
+ temperature=0.9,
74
+ top_p=0.95,
75
+ max_new_tokens=200,
76
+ do_sample=True,
77
+ ),
78
+ )
79
+ # Expected output format is a list of GenerationOuput where content type could be `ContentType.TEXT' or `ContentType.SPEECH`:
80
+ # [GenerationOuput(content='xxx', content_type=<ContentType.TEXT: 'TEXT'>), GenerationOuput(content=array([ 0.00553902, -0.03210586, ... ], dtype=float32), content_type=<ContentType.SPEECH: 'SPEECH'>), GenerationOuput(content='yyy', content_type=<ContentType.TEXT: 'TEXT'>), GenerationOuput(content=array([0.04051103, 0.03596291, 0.03381396, ..., 0.05103811, 0.05429034, ..,,], dtype=float32), content_type=<ContentType.SPEECH: 'SPEECH'>)]
81
+ ```
82
+ See more examples with other modalities in [examples/speech_generation/spirit_model.ipynb](../../examples/speech_generation/spirit_model.ipynb).