Commit 43acf85 (1 parent: 2955de7)
added codes and tokenizers

This view is limited to 50 files because it contains too many changes. See raw diff.
- CODE_OF_CONDUCT.md +80 -0
- CONTRIBUTING.md +31 -0
- LICENSE +116 -0
- MODEL_CARD.md +81 -0
- README.md +55 -0
- assets/spiritlm_overview.png +0 -0
- checkpoints/README.md +52 -0
- checkpoints/speech_tokenizer/hifigan_spiritlm_base/config.json +55 -0
- checkpoints/speech_tokenizer/hifigan_spiritlm_base/generator.pt +3 -0
- checkpoints/speech_tokenizer/hifigan_spiritlm_base/speakers.txt +4 -0
- checkpoints/speech_tokenizer/hifigan_spiritlm_base/styles.txt +34 -0
- checkpoints/speech_tokenizer/hifigan_spiritlm_expressive_w2v2/config.json +60 -0
- checkpoints/speech_tokenizer/hifigan_spiritlm_expressive_w2v2/generator.pt +3 -0
- checkpoints/speech_tokenizer/hifigan_spiritlm_expressive_w2v2/speakers.txt +4 -0
- checkpoints/speech_tokenizer/hubert_25hz/L11_quantizer_500.pt +3 -0
- checkpoints/speech_tokenizer/hubert_25hz/mhubert_base_25hz.pt +3 -0
- checkpoints/speech_tokenizer/style_encoder_w2v2/config.json +321 -0
- checkpoints/speech_tokenizer/style_encoder_w2v2/pytorch_model.bin +3 -0
- checkpoints/speech_tokenizer/vqvae_f0_quantizer/config.yaml +59 -0
- checkpoints/speech_tokenizer/vqvae_f0_quantizer/model.pt +3 -0
- data/examples/pred.jsonl +5 -0
- data/examples/ref.jsonl +5 -0
- env.yml +19 -0
- examples/audio/7143-88743-0029.flac +0 -0
- examples/distributed_inference_recipe/multi_nodes.slurm +24 -0
- examples/distributed_inference_recipe/run_dist.py +89 -0
- examples/speech_generation/spirit_model.ipynb +0 -0
- examples/speech_tokenizer/spiritlm_speech_tokenizer.ipynb +0 -0
- requirements.dev.txt +1 -0
- requirements.txt +9 -0
- setup.py +60 -0
- spiritlm.egg-info/PKG-INFO +84 -0
- spiritlm.egg-info/SOURCES.txt +32 -0
- spiritlm.egg-info/dependency_links.txt +1 -0
- spiritlm.egg-info/not-zip-safe +1 -0
- spiritlm.egg-info/requires.txt +15 -0
- spiritlm.egg-info/top_level.txt +2 -0
- spiritlm/__init__.py +5 -0
- spiritlm/__pycache__/__init__.cpython-310.pyc +0 -0
- spiritlm/eval/README.md +92 -0
- spiritlm/eval/eval_stsp.py +87 -0
- spiritlm/eval/load_data.py +50 -0
- spiritlm/eval/stsp/few_shot_prompt.py +101 -0
- spiritlm/eval/stsp/predict_stsp.py +299 -0
- spiritlm/eval/stsp/sanity_check_download.py +30 -0
- spiritlm/eval/stsp/sentiment_classifiers.py +37 -0
- spiritlm/eval/stsp/stsp_constants.py +12 -0
- spiritlm/eval/stsp/utils.py +122 -0
- spiritlm/eval/utils.py +17 -0
- spiritlm/model/README.md +82 -0
CODE_OF_CONDUCT.md
ADDED
@@ -0,0 +1,80 @@
+# Code of Conduct
+
+## Our Pledge
+
+In the interest of fostering an open and welcoming environment, we as
+contributors and maintainers pledge to make participation in our project and
+our community a harassment-free experience for everyone, regardless of age, body
+size, disability, ethnicity, sex characteristics, gender identity and expression,
+level of experience, education, socio-economic status, nationality, personal
+appearance, race, religion, or sexual identity and orientation.
+
+## Our Standards
+
+Examples of behavior that contributes to creating a positive environment
+include:
+
+* Using welcoming and inclusive language
+* Being respectful of differing viewpoints and experiences
+* Gracefully accepting constructive criticism
+* Focusing on what is best for the community
+* Showing empathy towards other community members
+
+Examples of unacceptable behavior by participants include:
+
+* The use of sexualized language or imagery and unwelcome sexual attention or
+  advances
+* Trolling, insulting/derogatory comments, and personal or political attacks
+* Public or private harassment
+* Publishing others' private information, such as a physical or electronic
+  address, without explicit permission
+* Other conduct which could reasonably be considered inappropriate in a
+  professional setting
+
+## Our Responsibilities
+
+Project maintainers are responsible for clarifying the standards of acceptable
+behavior and are expected to take appropriate and fair corrective action in
+response to any instances of unacceptable behavior.
+
+Project maintainers have the right and responsibility to remove, edit, or
+reject comments, commits, code, wiki edits, issues, and other contributions
+that are not aligned to this Code of Conduct, or to ban temporarily or
+permanently any contributor for other behaviors that they deem inappropriate,
+threatening, offensive, or harmful.
+
+## Scope
+
+This Code of Conduct applies within all project spaces, and it also applies when
+an individual is representing the project or its community in public spaces.
+Examples of representing a project or community include using an official
+project e-mail address, posting via an official social media account, or acting
+as an appointed representative at an online or offline event. Representation of
+a project may be further defined and clarified by project maintainers.
+
+This Code of Conduct also applies outside the project spaces when there is a
+reasonable belief that an individual's behavior may have a negative impact on
+the project or its community.
+
+## Enforcement
+
+Instances of abusive, harassing, or otherwise unacceptable behavior may be
+reported by contacting the project team at <[email protected]>. All
+complaints will be reviewed and investigated and will result in a response that
+is deemed necessary and appropriate to the circumstances. The project team is
+obligated to maintain confidentiality with regard to the reporter of an incident.
+Further details of specific enforcement policies may be posted separately.
+
+Project maintainers who do not follow or enforce the Code of Conduct in good
+faith may face temporary or permanent repercussions as determined by other
+members of the project's leadership.
+
+## Attribution
+
+This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
+available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
+
+[homepage]: https://www.contributor-covenant.org
+
+For answers to common questions about this code of conduct, see
+https://www.contributor-covenant.org/faq
CONTRIBUTING.md
ADDED
@@ -0,0 +1,31 @@
+# Contributing to spiritlm
+We want to make contributing to this project as easy and transparent as
+possible.
+
+## Pull Requests
+We actively welcome your pull requests.
+
+1. Fork the repo and create your branch from `main`.
+2. If you've added code that should be tested, add tests.
+3. If you've changed APIs, update the documentation.
+4. Ensure the test suite passes.
+5. Make sure your code lints.
+6. If you haven't already, complete the Contributor License Agreement ("CLA").
+
+## Contributor License Agreement ("CLA")
+In order to accept your pull request, we need you to submit a CLA. You only need
+to do this once to work on any of Facebook's open source projects.
+
+Complete your CLA here: <https://code.facebook.com/cla>
+
+## Issues
+We use GitHub issues to track public bugs. Please ensure your description is
+clear and has sufficient instructions to be able to reproduce the issue.
+
+Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe
+disclosure of security bugs. In those cases, please go through the process
+outlined on that page and do not file a public issue.
+
+## License
+By contributing to spiritlm, you agree that your contributions will be licensed
+under the LICENSE file in the root directory of this source tree.
LICENSE
ADDED
@@ -0,0 +1,116 @@
+FAIR Noncommercial Research License
+Last Updated: October 18, 2024
+
+“Acceptable Use Policy” means the FAIR Acceptable Use Policy, applicable to Research Materials, that is incorporated into this Agreement.
+
+“Agreement” means the terms and conditions for use, reproduction, distribution and modification of the Research Materials set forth herein.
+
+“Documentation” means the specifications, manuals and documentation accompanying
+Research Materials distributed by Meta.
+
+“Licensee” or “you” means you, or your employer or any other person or entity (if you are entering into this Agreement on such person or entity’s behalf), of the age required under applicable laws, rules or regulations to provide legal consent and that has legal authority to bind your employer or such other person or entity if you are entering in this Agreement on their behalf.
+
+“Meta” or “we” means Meta Platforms Ireland Limited (if you are located in or, if you are an entity, your principal place of business is in the EEA or Switzerland) and Meta Platforms, Inc. (if you are located outside of the EEA or Switzerland).
+
+“Noncommercial Research Uses” means noncommercial research use cases related to research, development, education, processing, or analysis and in each case, is not primarily intended for commercial advantage or monetary compensation to you or others.
+
+“Research Materials” means, collectively, Documentation and the models, software and algorithms, including machine-learning model code, trained model weights, inference-enabling code, training-enabling code, fine-tuning enabling code, demonstration materials and other elements of the foregoing distributed by Meta and made available under this Agreement.
+
+By clicking “I Accept” below or by using or distributing any portion or element of the Research Materials, you agree to be bound by this Agreement.
+
+
+1. License Rights and Redistribution.
+
+a. Grant of Rights. You are granted a non-exclusive, worldwide, non-transferable and royalty-free limited license under Meta’s intellectual property or other rights owned by Meta embodied in the Research Materials to use, reproduce, distribute, copy, create derivative works of, and make modifications to the Research Materials.
+
+b. Redistribution and Use.
+i. You will not use the Research Materials or any outputs or results of the Research Materials in connection with any commercial uses or for any uses other than Noncommercial Research Uses;
+
+ii. Distribution of Research Materials, and any derivative works thereof, are subject to the terms of this Agreement. If you distribute or make the Research Materials, or any derivative works thereof, available to a third party, you may only do so under the terms of this Agreement. You shall also provide a copy of this Agreement to such third party.
+
+iii. If you submit for publication the results of research you perform on, using, or otherwise in connection with Research Materials, you must acknowledge the use of Research Materials in your publication.
+
+iv. Your use of the Research Materials must comply with applicable laws and regulations (including Trade Control Laws) and adhere to the FAIR Acceptable Use Policy, which is hereby incorporated by reference into this Agreement.
+
+2. User Support. Your Noncommercial Research Use of the Research Materials is done at your own discretion; Meta does not process any information nor provide any service in relation to such use. Meta is under no obligation to provide any support services for the Research Materials. Any support provided is “as is”, “with all faults”, and without warranty of any kind.
+
+3. Disclaimer of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE RESEARCH MATERIALS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN “AS IS” BASIS, WITHOUT WARRANTIES OF ANY KIND, AND META DISCLAIMS ALL WARRANTIES OF ANY KIND, BOTH EXPRESS AND IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING THE RESEARCH MATERIALS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE RESEARCH MATERIALS AND ANY OUTPUT AND RESULTS.
+
+4. Limitation of Liability. IN NO EVENT WILL META OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY DIRECT OR INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF META OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
+
+5. Intellectual Property.
+
+a. Subject to Meta’s ownership of Research Materials and derivatives made by or for Meta, with respect to any derivative works and modifications of the Research Materials that are made by you, as between you and Meta, you are and will be the owner of such derivative works and modifications.
+
+b. If you institute litigation or other proceedings against Meta or any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Research Materials, outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by you, then any licenses granted to you under this Agreement shall terminate as of the date such litigation or claim is filed or instituted. You will indemnify and hold harmless Meta from and against any claim by any third party arising out of or related to your use or distribution of the Research Materials.
+
+6. Term and Termination. The term of this Agreement will commence upon your acceptance of this Agreement or access to the Research Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein. Meta may terminate this Agreement if you are in breach of any term or condition of this Agreement. Upon termination of this Agreement, you shall delete and cease use of the Research Materials. Sections 5, 6 and 9 shall survive the termination of this Agreement.
+
+7. Governing Law and Jurisdiction. This Agreement will be governed and construed under the laws of the State of California without regard to choice of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this Agreement. The courts of California shall have exclusive jurisdiction of any dispute arising out of this Agreement.
+
+8. Modifications and Amendments. Meta may modify this Agreement from time to time by posting a revised version at [https://github.com/facebookresearch/spiritlm/blob/main/LICENSE]; provided that they are similar in spirit to the current version of the Agreement, but may differ in detail to address new problems or concerns. All such changes will be effective immediately. Your continued use of the Research Materials after any modification to this Agreement constitutes your agreement to such modification. Except as provided in this Agreement, no modification or addition to any provision of this Agreement will be binding unless it is in writing and signed by an authorized representative of both you and Meta.
+
+
+FAIR Acceptable Use Policy
+
+The Fundamental AI Research (FAIR) team at Meta seeks to further understanding of new and existing research domains with the mission of advancing the state-of-the-art in artificial intelligence through open research for the benefit of all.
+
+As part of this mission, Meta makes certain research materials available for noncommercial research use. Meta is committed to promoting the safe and responsible use of such research materials.
+
+
+Prohibited Uses
+
+You agree you will not use, or allow others to use, Research Materials to:
+
+1. Violate the law or others’ rights, including to:
+a. Engage in, promote, generate, contribute to, encourage, plan, incite, or further illegal or unlawful activity or content, such as:
+i. Violence or terrorism
+ii. Exploitation or harm to children, including the solicitation, creation, acquisition, or dissemination of child exploitative content or failure to report Child Sexual Abuse Material
+iii. Human trafficking, exploitation, and sexual violence
+iv. The illegal distribution of information or materials to minors, including obscene materials, or failure to employ legally required age-gating in connection with such information or materials.
+v. Sexual solicitation
+vi. Any other criminal activity
+
+b. Engage in, promote, incite, or facilitate the harassment, abuse, threatening, or bullying of individuals or groups of individuals
+
+c. Engage in, promote, incite, or facilitate discrimination or other unlawful or harmful conduct in the provision of employment, employment benefits, credit, housing, other economic benefits, or other essential goods and services
+
+d. Engage in the unauthorized or unlicensed practice of any profession including, but not limited to, financial, legal, medical/health, or related professional practices
+
+e. Collect, process, disclose, generate, or infer health, demographic, or other sensitive personal or private information about individuals without rights and consents required by applicable laws
+
+f. Engage in or facilitate any action or generate any content that infringes, misappropriates, or otherwise violates any third-party rights, including the outputs or results of any technology using FAIR research materials
+
+g. Create, generate, or facilitate the creation of malicious code, malware, computer viruses or do anything else that could disable, overburden, interfere with or impair the proper working, integrity, operation or appearance of a website or computer system
+
+2. Engage in, promote, incite, facilitate, or assist in the planning or development of activities that present a risk of death or bodily harm to individuals, including use of research artifacts related to the following:
+
+a. Military, warfare, nuclear industries or applications, espionage, use for materials or activities that are subject to the International Traffic Arms Regulations (ITAR) maintained by the United States Department of State
+
+b. Guns and illegal weapons (including weapon development)
+
+c. Illegal drugs and regulated/controlled substances
+
+d. Operation of critical infrastructure, transportation technologies, or heavy machinery
+
+e. Self-harm or harm to others, including suicide, cutting, and eating disorders
+
+f. Any content intended to incite or promote violence, abuse, or any infliction of bodily harm to an individual
+
+3. Intentionally deceive or mislead others, including use of FAIR Research Materials related to the following:
+
+a. Generating, promoting, or furthering fraud or the creation or promotion of disinformation
+
+b. Generating, promoting, or furthering defamatory content, including the creation of defamatory statements, images, or other content
+
+c. Generating, promoting, or further distributing spam
+
+d. Impersonating another individual without consent, authorization, or legal right
+
+e. Representing that outputs of FAIR research materials or outputs from technology using FAIR research materials are human-generated
+
+f. Generating or facilitating false online engagement, including fake reviews and other means of fake online engagement
+
+4. Fail to appropriately disclose to end users any known dangers of your Research Materials.
+
+Please report any violation of this Policy or other problems that could lead to a violation of this Policy by submitting a report here [https://docs.google.com/forms/d/e/1FAIpQLSeb11cryAopJ7LNrC4nxEUXrHY26hfkXQMf_uH-oFgA3WlYZQ/viewform].
MODEL_CARD.md
ADDED
@@ -0,0 +1,81 @@
+# Meta Spirit LM Model Card
+
+## Model Details
+
+*Note: Use of this model is governed by the FAIR Noncommercial Research License.*
+
+Spirit LM is a multimodal language model that freely mixes text and speech. The model can be prompted with either text or speech and is capable of generating outputs in either modality, while preserving the expressivity of the input prompt. The model is also able to learn new tasks across modalities, such as automatic speech recognition, text-to-speech, and speech classification, in a few-shot manner.
+
+## Model Developers
+Meta
+
+## Variations
+Spirit LM comes in two versions: Spirit LM Base, which uses speech phonetic tokens, and Spirit LM Expressive, which models expressivity using pitch and style tokens in addition to the phonetic tokens.
+
+## Input
+Models take text, speech, or a mixed sequence of the two as input.
+
+## Output
+Models generate text, speech, or a mixed sequence of the two.
+
+## Model Architecture
+### Speech Tokenizer
+Spirit LM uses 3 types of speech tokenizers: a Phonetic Tokenizer (HuBERT), a Pitch Tokenizer (VQ-VAE) and a Style Tokenizer (Speechprop or Wav2vec2). We use Hifi-GAN to convert the speech tokens back to audio.
+
+Note that in the associated paper, for Spirit LM Expressive, we used Speechprop to extract style tokens, while we use a Wav2vec2 model to extract style tokens in this release.
+
+|                           | Model                     | Parameters | Input                         | Output           |
+|---------------------------|---------------------------|------------|-------------------------------|------------------|
+| Phonetic Tokenizer        | HuBERT+LinearQuantizer    | 96M        | Waveform                      | Phonetic Tokens  |
+| Pitch Tokenizer           | VQ-VAE                    | 0.2M       | Extracted F0                  | Pitch Tokens     |
+| Style Tokenizer           | Wav2vec2+LinearProjection | 95M        | Waveform                      | Style Tokens     |
+| Base Speech Decoder       | Hifi-GAN                  | 14M        | Phonetic Tokens               | Waveform         |
+| Expressive Speech Decoder | Hifi-GAN                  | 15M        | Phonetic, Pitch, Style Tokens | Waveform         |
+
+### Language Model
+Spirit LM is initialized from the Llama-2 7B model.
+
+|                      | Architecture | Parameters | Input/Output Tokens                                      | Vocab Size |
+|----------------------|--------------|------------|----------------------------------------------------------|------------|
+| Spirit LM Base       | Llama-2 7B   | 7B         | Text Tokens, Phonetic Tokens                             | 32512      |
+| Spirit LM Expressive | Llama-2 7B   | 7B         | Text Tokens, Phonetic Tokens, Pitch Tokens, Style Tokens | 32768      |
+
+### Release Date
+The models were trained between October and December 2023. The research paper was released on February 8th, 2024. We released the model on October 18th, 2024.
+
+### Status
+This is a static model trained on an offline dataset.
+
+### License
+We release the model under the FAIR Noncommercial Research License found in the [LICENSE](LICENSE) file in the root directory of this repo.
+
+### Research Paper
+More information can be found in the paper ["SpiRit-LM: Interleaved Spoken and Written Language Model"](https://arxiv.org/pdf/2402.05755.pdf).
+
+## Hardware and Software
+### Training Factors
+We used custom training libraries. The training of the released models was performed on Meta’s Research Clusters.
+
+The training of each model (Spirit LM Base and Spirit LM Expressive) took 21K GPU-hours of computation on A100-80GB hardware (TDP of 350-400W), not including the training of Llama-2.
+
+## Training Data
+We trained the models on a combination of text-only datasets, speech-only datasets, and aligned speech-text datasets. All the speech datasets are publicly available. Here are the statistics of the datasets we used:
+
+|              | Hours | Speech Tokens | Text Tokens |
+|--------------|-------|---------------|-------------|
+| Speech-only  | 458K  | 28.2B         | -           |
+| Speech+Text  | 111K  | 7.0B          | 1.4B        |
+| Text-only    | -     | -             | 307B        |
+
+## Evaluation Results
+See evaluations for our models and detailed ablations in Sections 4 and 5, and safety evaluations in Section 6, of the [research paper](https://arxiv.org/pdf/2402.05755.pdf).
+
+## Intended Use
+### Intended Use Cases
+Spirit LM is intended for noncommercial research use in English.
+
+### Out-of-Scope Uses
+Use in any manner that violates applicable laws or regulations (including trade compliance laws). Use in languages other than English. Use in any other way that is prohibited by the FAIR Noncommercial Research License and Acceptable Use Policy.
+
+## Ethical Considerations and Limitations
+This model is built on Llama 2, which carries risks with use. Testing conducted to date has been in English and has not covered, nor could it cover, all scenarios. For these reasons, as with all LLMs, Llama 2’s potential outputs cannot be predicted in advance, and the model may in some instances produce inaccurate, biased or otherwise objectionable responses to user prompts. The model’s speech capabilities are designed to analyze speaker-agnostic qualities of any input speech and output speech in one of four pre-set voices. The model is meant for noncommercial research purposes only and should not be deployed in any consumer-facing applications.
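A quick arithmetic note on the vocab sizes in the language-model table: both variants extend Llama-2's 32,000-entry text vocabulary, with Spirit LM Base adding 512 entries and Spirit LM Expressive adding 768. The card does not break down how those added entries split across phonetic, pitch, style, and special tokens, so the sketch below only checks the totals:

```python
LLAMA2_VOCAB = 32000  # Llama-2 tokenizer size

for name, vocab in [("Spirit LM Base", 32512), ("Spirit LM Expressive", 32768)]:
    extra = vocab - LLAMA2_VOCAB
    print(f"{name}: {vocab} total = 32000 text + {extra} added tokens")
# Spirit LM Base: 32512 total = 32000 text + 512 added tokens
# Spirit LM Expressive: 32768 total = 32000 text + 768 added tokens
```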
README.md
ADDED
@@ -0,0 +1,55 @@
+# Meta Spirit LM: Interleaved Spoken and Written Language Model
+
+This repository contains the model weights, inference code and evaluation scripts for the Spirit LM [paper](https://arxiv.org/pdf/2402.05755.pdf). You can find more generation samples on our [demo page](https://speechbot.github.io/spiritlm/).
+
+## Spirit LM Model Overview
+<img src="assets/spiritlm_overview.png">
+
+## Installation Setup
+### Conda
+```
+conda env create -f env.yml
+pip install -e '.[eval]'
+```
+### Pip
+```
+pip install -e '.[eval]'
+```
+
+### Dev
+(Optional: only needed if you want to run the tests.)
+```
+pip install -e '.[dev]'
+```
+
+## Checkpoints Setup
+See [checkpoints/README.md](checkpoints/README.md)
+
+## Quick Start
+### Speech Tokenization
+See [spiritlm/speech_tokenizer/README.md](spiritlm/speech_tokenizer/README.md)
+### Spirit LM Generation
+See [spiritlm/model/README.md](spiritlm/model/README.md)
+### Speech-Text Sentiment Preservation benchmark (STSP)
+See [spiritlm/eval/README.md](spiritlm/eval/README.md)
+
+## Model Card
+More details of the model can be found in [MODEL_CARD.md](MODEL_CARD.md).
+
+## License
+The present code is provided under the **FAIR Noncommercial Research License** found in [LICENSE](LICENSE).
+
+## Citation
+```
+@misc{nguyen2024spiritlminterleavedspokenwritten,
+      title={SpiRit-LM: Interleaved Spoken and Written Language Model},
+      author={Tu Anh Nguyen and Benjamin Muller and Bokai Yu and Marta R. Costa-jussa and Maha Elbayad and Sravya Popuri and Paul-Ambroise Duquenne and Robin Algayres and Ruslan Mavlyutov and Itai Gat and Gabriel Synnaeve and Juan Pino and Benoit Sagot and Emmanuel Dupoux},
+      year={2024},
+      eprint={2402.05755},
+      archivePrefix={arXiv},
+      primaryClass={cs.CL},
+      url={https://arxiv.org/abs/2402.05755},
+}
+```
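For a first taste of the generation API before diving into the per-module READMEs, here is a minimal single-process sketch assembled from the same calls used in examples/distributed_inference_recipe/run_dist.py later in this commit (the audio path is the bundled example; the sampling parameters are illustrative):

```python
import torchaudio
from transformers import GenerationConfig
from spiritlm.model.spiritlm_model import (
    ContentType,
    GenerationInput,
    OutputModality,
    Spiritlm,
)

# Load the expressive variant (checkpoints must be set up as described above).
spirit_lm = Spiritlm("spirit-lm-expressive-7b")

# Prompt with a speech waveform and let the model continue in any modality.
wav = torchaudio.load("examples/audio/7143-88743-0029.flac")[0].squeeze()
outs = spirit_lm.generate(
    output_modality=OutputModality.ARBITRARY,
    interleaved_inputs=[
        GenerationInput(content=wav, content_type=ContentType.SPEECH)
    ],
    generation_config=GenerationConfig(
        temperature=0.9, top_p=0.95, max_new_tokens=200, do_sample=True
    ),
)
print(outs)
```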
assets/spiritlm_overview.png
ADDED
checkpoints/README.md
ADDED
@@ -0,0 +1,52 @@
+# Spirit LM Checkpoints
+
+## Download Checkpoints
+To access and download the Spirit LM checkpoints, please request the model artifacts at this link:
+
+[https://ai.meta.com/resources/models-and-libraries/spirit-lm-downloads/](https://ai.meta.com/resources/models-and-libraries/spirit-lm-downloads/)
+
+Upon approval, you will then receive an email with download links to each model artifact.
+
+Please note that Spirit LM is made available under the **FAIR Noncommercial Research License**
+found in the [LICENSE](../LICENSE) file in the root directory of this source tree, and the Acceptable Use Policy.
+
+## Structure
+The checkpoints directory should look like this:
+```
+checkpoints/
+├── README.md
+├── speech_tokenizer
+│   ├── hifigan_spiritlm_base
+│   │   ├── config.json
+│   │   ├── generator.pt
+│   │   ├── speakers.txt
+│   │   └── styles.txt
+│   ├── hifigan_spiritlm_expressive_w2v2
+│   │   ├── config.json
+│   │   ├── generator.pt
+│   │   └── speakers.txt
+│   ├── hubert_25hz
+│   │   ├── L11_quantizer_500.pt
+│   │   └── mhubert_base_25hz.pt
+│   ├── style_encoder_w2v2
+│   │   ├── config.json
+│   │   └── pytorch_model.bin
+│   └── vqvae_f0_quantizer
+│       ├── config.yaml
+│       └── model.pt
+└── spiritlm_model
+    ├── spirit-lm-base-7b
+    │   ├── config.json
+    │   ├── generation_config.json
+    │   ├── pytorch_model.bin
+    │   ├── special_tokens_map.json
+    │   ├── tokenizer_config.json
+    │   └── tokenizer.model
+    └── spirit-lm-expressive-7b
+        ├── config.json
+        ├── generation_config.json
+        ├── pytorch_model.bin
+        ├── special_tokens_map.json
+        ├── tokenizer_config.json
+        └── tokenizer.model
+```
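Once the artifacts are downloaded, a quick way to confirm the tree matches the layout above is a small path check. This standalone sketch is only illustrative (the spiritlm/eval/stsp/sanity_check_download.py script added in this commit targets the STSP evaluation data, not the checkpoints):

```python
from pathlib import Path

# A few representative files from the layout above; extend as needed.
EXPECTED = [
    "speech_tokenizer/hifigan_spiritlm_base/generator.pt",
    "speech_tokenizer/hubert_25hz/mhubert_base_25hz.pt",
    "speech_tokenizer/style_encoder_w2v2/pytorch_model.bin",
    "speech_tokenizer/vqvae_f0_quantizer/model.pt",
    "spiritlm_model/spirit-lm-base-7b/pytorch_model.bin",
    "spiritlm_model/spirit-lm-expressive-7b/pytorch_model.bin",
]

root = Path("checkpoints")
missing = [p for p in EXPECTED if not (root / p).is_file()]
print("all expected files present" if not missing else f"missing: {missing}")
```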
checkpoints/speech_tokenizer/hifigan_spiritlm_base/config.json
ADDED
@@ -0,0 +1,55 @@
+{
+
+ "resblock": "1",
+ "num_gpus": 0,
+ "batch_size": 16,
+ "learning_rate": 0.0002,
+ "adam_b1": 0.8,
+ "adam_b2": 0.99,
+ "lr_decay": 0.999,
+ "seed": 1234,
+
+ "upsample_rates": [5,4,4,4,2],
+ "upsample_kernel_sizes": [11,8,8,8,4],
+ "upsample_initial_channel": 512,
+ "resblock_kernel_sizes": [3,7,11],
+ "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
+ "num_embeddings": 501,
+ "embedding_dim": 128,
+ "model_in_dim": 384,
+
+ "segment_size": 8960,
+ "code_hop_size": 640,
+ "f0": false,
+ "num_mels": 80,
+ "num_freq": 1025,
+ "n_fft": 1024,
+ "hop_size": 256,
+ "win_size": 1024,
+
+ "multispkr": "from_input_file",
+ "num_speakers": 4,
+ "multistyle": "from_input_file",
+ "num_styles": 34,
+
+ "dur_prediction_weight": 1.0,
+ "dur_predictor_params": {
+  "encoder_embed_dim": 128,
+  "var_pred_hidden_dim": 128,
+  "var_pred_kernel_size": 3,
+  "var_pred_dropout": 0.5
+ },
+
+ "sampling_rate": 16000,
+
+ "fmin": 0,
+ "fmax": 8000,
+ "fmax_for_loss": null,
+
+ "num_workers": 4,
+
+ "dist_config": {
+  "dist_backend": "nccl",
+  "dist_url": "env://"
+ }
+}
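One internal consistency in this config worth spelling out: the product of upsample_rates equals code_hop_size, so the vocoder synthesizes 640 waveform samples per input unit, which at the 16 kHz sampling_rate corresponds to the 25 Hz rate of the HuBERT tokenizer:

```python
import math

upsample_rates = [5, 4, 4, 4, 2]  # from the config above
code_hop_size = 640
sampling_rate = 16000

assert math.prod(upsample_rates) == code_hop_size  # 5*4*4*4*2 = 640
print(f"unit rate: {sampling_rate / code_hop_size:.0f} Hz")  # 25 Hz
```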
checkpoints/speech_tokenizer/hifigan_spiritlm_base/generator.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d66c49067aeff93b14b038f143c2dc9ed981671956512d8e702897416f13c459
+size 57512631
checkpoints/speech_tokenizer/hifigan_spiritlm_base/speakers.txt
ADDED
@@ -0,0 +1,4 @@
+ex04
+ex02
+ex03
+ex01
checkpoints/speech_tokenizer/hifigan_spiritlm_base/styles.txt
ADDED
@@ -0,0 +1,34 @@
+read-default
+read-enunciated
+read-confused
+read-laughing
+read-whisper
+read-sad
+read-happy
+conv-projected
+conv-default
+conv-sympathetic
+conv-fast
+conv-disgusted
+conv-laughing
+conv-calm
+conv-sarcastic
+conv-whisper
+conv-angry
+conv-sad
+conv-happy
+conv-enunciated
+conv-awe
+read-singing
+conv-confused
+conv-fearful
+conv-narration
+conv-sleepy
+conv-child
+conv-animal
+conv-childdir
+conv-animaldir
+conv-bored
+conv-desire
+conv-nonverbal
+read-narration
checkpoints/speech_tokenizer/hifigan_spiritlm_expressive_w2v2/config.json
ADDED
@@ -0,0 +1,60 @@
+{
+
+ "resblock": "1",
+ "num_gpus": 0,
+ "batch_size": 128,
+ "learning_rate": 0.0002,
+ "adam_b1": 0.8,
+ "adam_b2": 0.99,
+ "lr_decay": 0.999,
+ "seed": 1234,
+
+ "upsample_rates": [5,4,4,4,2],
+ "upsample_kernel_sizes": [11,8,8,8,4],
+ "upsample_initial_channel": 512,
+ "resblock_kernel_sizes": [3,7,11],
+ "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
+
+ "multispkr": "from_input_file",
+ "multistyle": null,
+
+ "dur_prediction_weight": 1.0,
+ "dur_predictor_params": {
+  "encoder_embed_dim": 128,
+  "var_pred_hidden_dim": 128,
+  "var_pred_kernel_size": 3,
+  "var_pred_dropout": 0.5
+ },
+
+ "segment_size": 17920,
+ "code_hop_size": 640,
+ "f0_hop_size": 1280,
+ "style_hop_size": 16000,
+
+ "num_embeddings": 501,
+ "num_f0_tokens": 64,
+ "num_style_tokens": 100,
+ "num_speakers": 4,
+
+ "embedding_dim": 128,
+ "model_in_dim": 512,
+
+ "num_mels": 80,
+ "num_freq": 1025,
+ "n_fft": 1024,
+ "hop_size": 256,
+ "win_size": 1024,
+
+ "sampling_rate": 16000,
+
+ "fmin": 0,
+ "fmax": 8000,
+ "fmax_for_loss": null,
+
+ "num_workers": 4,
+
+ "dist_config": {
+  "dist_backend": "nccl",
+  "dist_url": "env://"
+ }
+}
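The three hop sizes in this expressive config imply three different token rates at the 16 kHz sampling rate, with style changing far more slowly than phonetic content; a small check:

```python
sampling_rate = 16000  # from the config above
for name, hop in [("phonetic (code_hop_size)", 640),
                  ("pitch (f0_hop_size)", 1280),
                  ("style (style_hop_size)", 16000)]:
    print(f"{name}: {sampling_rate / hop:g} tokens/s")
# phonetic (code_hop_size): 25 tokens/s
# pitch (f0_hop_size): 12.5 tokens/s
# style (style_hop_size): 1 tokens/s
```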
checkpoints/speech_tokenizer/hifigan_spiritlm_expressive_w2v2/generator.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9de97cbc336e6113c27f17560988a22e76df25af6a8e944ad01676aabec31326
+size 59414584
checkpoints/speech_tokenizer/hifigan_spiritlm_expressive_w2v2/speakers.txt
ADDED
@@ -0,0 +1,4 @@
+ex04
+ex02
+ex03
+ex01
checkpoints/speech_tokenizer/hubert_25hz/L11_quantizer_500.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:06b408c0a487f0218e8aba52ca30d9de54e0c36af8bebfd151265647d221080b
+size 5222060
checkpoints/speech_tokenizer/hubert_25hz/mhubert_base_25hz.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1421f060cf92b9d2ea72dcecd7f30ce549e1112e24b62b43f2ec3026301051bb
+size 383333938
checkpoints/speech_tokenizer/style_encoder_w2v2/config.json
ADDED
@@ -0,0 +1,321 @@
+{
+ "_name_or_path": "facebook/wav2vec2-base",
+ "activation_dropout": 0.0,
+ "adapter_kernel_size": 3,
+ "adapter_stride": 2,
+ "add_adapter": false,
+ "apply_spec_augment": true,
+ "architectures": ["Wav2Vec2ForPooledSequenceClassification"],
+ "attention_dropout": 0.1,
+ "bos_token_id": 1,
+ "classifier_proj_size": 256,
+ "codevector_dim": 256,
+ "contrastive_logits_temperature": 0.1,
+ "conv_bias": false,
+ "conv_dim": [512, 512, 512, 512, 512, 512, 512],
+ "conv_kernel": [10, 3, 3, 3, 3, 2, 2],
+ "conv_stride": [5, 2, 2, 2, 2, 2, 2],
+ "ctc_loss_reduction": "sum",
+ "ctc_zero_infinity": false,
+ "diversity_loss_weight": 0.1,
+ "do_stable_layer_norm": false,
+ "eos_token_id": 2,
+ "feat_extract_activation": "gelu",
+ "feat_extract_norm": "group",
+ "feat_proj_dropout": 0.1,
+ "feat_quantizer_dropout": 0.0,
+ "final_dropout": 0.0,
+ "freeze_feat_extract_train": true,
+ "hidden_act": "gelu",
+ "hidden_dropout": 0.1,
+ "hidden_size": 768,
+ "id2label": {
+  "0": "0", "1": "1", "2": "2", "3": "3", "4": "4", "5": "5", "6": "6", "7": "7", "8": "8", "9": "9",
+  "10": "10", "11": "11", "12": "12", "13": "13", "14": "14", "15": "15", "16": "16", "17": "17", "18": "18", "19": "19",
+  "20": "20", "21": "21", "22": "22", "23": "23", "24": "24", "25": "25", "26": "26", "27": "27", "28": "28", "29": "29",
+  "30": "30", "31": "31", "32": "32", "33": "33", "34": "34", "35": "35", "36": "36", "37": "37", "38": "38", "39": "39",
+  "40": "40", "41": "41", "42": "42", "43": "43", "44": "44", "45": "45", "46": "46", "47": "47", "48": "48", "49": "49",
+  "50": "50", "51": "51", "52": "52", "53": "53", "54": "54", "55": "55", "56": "56", "57": "57", "58": "58", "59": "59",
+  "60": "60", "61": "61", "62": "62", "63": "63", "64": "64", "65": "65", "66": "66", "67": "67", "68": "68", "69": "69",
+  "70": "70", "71": "71", "72": "72", "73": "73", "74": "74", "75": "75", "76": "76", "77": "77", "78": "78", "79": "79",
+  "80": "80", "81": "81", "82": "82", "83": "83", "84": "84", "85": "85", "86": "86", "87": "87", "88": "88", "89": "89",
+  "90": "90", "91": "91", "92": "92", "93": "93", "94": "94", "95": "95", "96": "96", "97": "97", "98": "98", "99": "99"
+ },
+ "initializer_range": 0.02,
+ "intermediate_size": 3072,
+ "label2id": {
+  "0": "0", "1": "1", "2": "2", "3": "3", "4": "4", "5": "5", "6": "6", "7": "7", "8": "8", "9": "9",
+  "10": "10", "11": "11", "12": "12", "13": "13", "14": "14", "15": "15", "16": "16", "17": "17", "18": "18", "19": "19",
+  "20": "20", "21": "21", "22": "22", "23": "23", "24": "24", "25": "25", "26": "26", "27": "27", "28": "28", "29": "29",
+  "30": "30", "31": "31", "32": "32", "33": "33", "34": "34", "35": "35", "36": "36", "37": "37", "38": "38", "39": "39",
+  "40": "40", "41": "41", "42": "42", "43": "43", "44": "44", "45": "45", "46": "46", "47": "47", "48": "48", "49": "49",
+  "50": "50", "51": "51", "52": "52", "53": "53", "54": "54", "55": "55", "56": "56", "57": "57", "58": "58", "59": "59",
+  "60": "60", "61": "61", "62": "62", "63": "63", "64": "64", "65": "65", "66": "66", "67": "67", "68": "68", "69": "69",
+  "70": "70", "71": "71", "72": "72", "73": "73", "74": "74", "75": "75", "76": "76", "77": "77", "78": "78", "79": "79",
+  "80": "80", "81": "81", "82": "82", "83": "83", "84": "84", "85": "85", "86": "86", "87": "87", "88": "88", "89": "89",
+  "90": "90", "91": "91", "92": "92", "93": "93", "94": "94", "95": "95", "96": "96", "97": "97", "98": "98", "99": "99"
+ },
+ "layer_norm_eps": 1e-05,
+ "layerdrop": 0.0,
+ "mask_channel_length": 10,
+ "mask_channel_min_space": 1,
+ "mask_channel_other": 0.0,
+ "mask_channel_prob": 0.0,
+ "mask_channel_selection": "static",
+ "mask_feature_length": 10,
+ "mask_feature_min_masks": 0,
+ "mask_feature_prob": 0.0,
+ "mask_time_length": 10,
+ "mask_time_min_masks": 2,
+ "mask_time_min_space": 1,
+ "mask_time_other": 0.0,
+ "mask_time_prob": 0.05,
+ "mask_time_selection": "static",
+ "model_type": "wav2vec2",
+ "no_mask_channel_overlap": false,
+ "no_mask_time_overlap": false,
+ "num_adapter_layers": 3,
+ "num_attention_heads": 12,
+ "num_codevector_groups": 2,
+ "num_codevectors_per_group": 320,
+ "num_conv_pos_embedding_groups": 16,
+ "num_conv_pos_embeddings": 128,
+ "num_feat_extract_layers": 7,
+ "num_hidden_layers": 12,
+ "num_negatives": 100,
+ "output_hidden_size": 768,
+ "pad_token_id": 0,
+ "proj_codevector_dim": 256,
+ "tdnn_dilation": [1, 2, 3, 1, 1],
+ "tdnn_dim": [512, 512, 512, 512, 1500],
+ "tdnn_kernel": [5, 3, 3, 1, 1],
+ "torch_dtype": "float32",
+ "transformers_version": "4.25.1",
+ "use_weighted_layer_sum": false,
+ "vocab_size": 32,
+ "xvector_output_dim": 512
+}
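The architectures field names a custom head (Wav2Vec2ForPooledSequenceClassification, defined in this repo rather than in transformers), but the config file itself is a stock wav2vec2 config with a 100-way label space matching num_style_tokens in the expressive HiFi-GAN config above. Assuming the checkpoints are laid out as in checkpoints/README.md, it can be inspected with the standard transformers loader:

```python
from transformers import Wav2Vec2Config

cfg = Wav2Vec2Config.from_pretrained("checkpoints/speech_tokenizer/style_encoder_w2v2")
print(cfg.model_type)         # wav2vec2
print(len(cfg.id2label))      # 100 style classes
print(cfg.num_hidden_layers)  # 12 (base-sized encoder)
```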
checkpoints/speech_tokenizer/style_encoder_w2v2/pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:819b7a87bfcf4041252c3f1825915924e10d5a06a2a66da546beaca98c7ab8bc
+size 378451177
checkpoints/speech_tokenizer/vqvae_f0_quantizer/config.yaml
ADDED
@@ -0,0 +1,59 @@
+
+seed: 1234
+
+# Data
+f0_path: ''
+p_train: 0.95
+min_frames: null
+batch_size: 128
+features: f0_interp,vuv
+out_features: norm_f0_interp,vuv
+segment_size: null
+segment_multi: 16
+num_workers: 4
+vuv_scale: 2
+speaker_stats: ''
+recon_loss_fn: l1_loss
+
+
+# Optimization
+learning_rate: 0.0002
+adam_b1: 0.8
+adam_b2: 0.99
+lr_decay: 0.999
+lambda_commit: 0.02
+
+# VQ params
+vq_params:
+  l_bins: 64
+  emb_width: 128
+  mu: 0.99
+  levels: 1
+
+# Encoder params
+encoder_params:
+  input_emb_width: 2
+  output_emb_width: 128
+  levels: 1
+  downs_t:
+  - 4
+  strides_t:
+  - 2
+  width: 32
+  depth: 4
+  m_conv: 1.0
+  dilation_growth_rate: 3
+
+# Decoder params
+decoder_params:
+  input_emb_width: 2
+  output_emb_width: 128
+  levels: 1
+  downs_t:
+  - 4
+  strides_t:
+  - 2
+  width: 32
+  depth: 4
+  m_conv: 1.0
+  dilation_growth_rate: 3
checkpoints/speech_tokenizer/vqvae_f0_quantizer/model.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f4321b0c19b47279ab3d76c4b6e85bbc439156a4dd6478919856accaf4180382
+size 2600601
data/examples/pred.jsonl
ADDED
@@ -0,0 +1,5 @@
+{"pred": "angry", "id": 4792320029370491913}
+{"pred": "neutral", "id": -5682350483296949563}
+{"pred": "amused", "id": -8754508989367964614}
+{"pred": "angry", "id": -9018665079841831624}
+{"pred": "neutral", "id": 1159246029716120600}
data/examples/ref.jsonl
ADDED
@@ -0,0 +1,5 @@
+{"emotion": "angry", "sentiment": "negative", "wav_path": "emov/sam/Angry/anger_281-308_0286.wav", "split": "test", "speaker": "sam", "id": 4792320029370491913}
+{"emotion": "neutral", "sentiment": "neutral", "wav_path": "emov/sam/Neutral/neutral_281-308_0286.wav", "split": "test", "speaker": "sam", "id": -5682350483296949563}
+{"emotion": "amused", "sentiment": "positive", "wav_path": "emov/sam/Amused/amused_281-308_0286.wav", "split": "test", "speaker": "sam", "id": -8754508989367964614}
+{"emotion": "angry", "sentiment": "negative", "wav_path": "emov/jenie/Angry/anger_57-84_0084.wav", "split": "test", "speaker": "jenie", "id": -9018665079841831624}
+{"emotion": "neutral", "sentiment": "neutral", "wav_path": "emov/jenie/Neutral/neutral_57-84_0084.wav", "split": "test", "speaker": "jenie", "id": 1159246029716120600}
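pred.jsonl and ref.jsonl pair up on the shared id field. The actual scorer is spiritlm/eval/eval_stsp.py added in this commit; the standalone sketch below is only illustrative, and its emotion-to-sentiment mapping is an assumption read off the five example records (pred stores emotion-level labels, while ref carries both emotion and sentiment):

```python
import json

def load_jsonl(path):
    with open(path) as f:
        return [json.loads(line) for line in f]

# Assumed mapping from predicted emotion labels to ref.jsonl's sentiment field.
EMOTION_TO_SENTIMENT = {"angry": "negative", "amused": "positive", "neutral": "neutral"}

# Index references by the shared "id" key, then compare sentiments.
refs = {r["id"]: r for r in load_jsonl("data/examples/ref.jsonl")}
preds = load_jsonl("data/examples/pred.jsonl")

correct = sum(
    EMOTION_TO_SENTIMENT[p["pred"]] == refs[p["id"]]["sentiment"] for p in preds
)
print(f"accuracy: {correct / len(preds):.2f}")
```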
env.yml
ADDED
@@ -0,0 +1,19 @@
+name: spiritlm_test
+channels:
+  - pytorch
+  - nvidia
+dependencies:
+  - python=3.9
+  - pip
+  - pytorch-cuda=11.8
+  - pytorch
+  - torchaudio
+  - pip:
+    - omegaconf==2.2.0
+    - librosa~=0.10
+    - local-attention~=1.9
+    - encodec~=0.1
+    - transformers
+    - fairscale~=0.4
+    - sentencepiece
+    - torchfcpe~=0.0.4
examples/audio/7143-88743-0029.flac
ADDED
Binary file (147 kB). View file
examples/distributed_inference_recipe/multi_nodes.slurm
ADDED
@@ -0,0 +1,24 @@
+#!/bin/bash
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the FAIR Noncommercial Research License
+# found in the LICENSE file in the root directory of this source tree.
+
+#SBATCH --job-name=spiritlm
+#SBATCH --ntasks-per-node=1
+#SBATCH --gpus-per-node=8
+#SBATCH --nodes=2
+#SBATCH --cpus-per-task=12
+#SBATCH --output=./logs/%j.stdout
+#SBATCH --error=./logs/%j.stderr
+#SBATCH --time=01:00:00
+
+set -e
+
+srun bash -c 'torchrun --nnodes $SLURM_JOB_NUM_NODES --nproc-per-node $SLURM_GPUS_ON_NODE \
+    --node-rank $SLURM_PROCID \
+    --master-addr $(scontrol show hostnames $SLURM_NODELIST | head -n1) \
+    --master-port 12345 \
+    examples/distributed_inference_recipe/run_dist.py'
examples/distributed_inference_recipe/run_dist.py
ADDED
@@ -0,0 +1,89 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the FAIR Noncommercial Research License
+# found in the LICENSE file in the root directory of this source tree.
+
+"""
+Usage example:
+
+cd {SPIRITLM ROOT FOLDER}
+export PYTHONPATH=.
+
+Single node, multi-gpus:
+(Assume that your machine has 8 GPUs)
+torchrun --nnodes 1 --nproc-per-node 8 examples/distributed_inference_recipe/run_dist.py
+
+Multi-nodes, multi-gpus:
+(2 nodes, 8 GPUs for each node, via sbatch)
+mkdir -p logs
+sbatch examples/distributed_inference_recipe/multi_nodes.slurm
+"""
+
+import os
+
+import torch
+import torch.distributed as dist
+import torchaudio
+from spiritlm.model.spiritlm_model import (
+    ContentType,
+    GenerationInput,
+    OutputModality,
+    Spiritlm,
+)
+from torch.utils.data import TensorDataset
+from torch.utils.data.distributed import DistributedSampler
+from transformers import GenerationConfig, set_seed
+
+
+def run(seed: int = 0):
+    world_size = int(os.environ["WORLD_SIZE"])
+    world_rank = int(os.environ["RANK"])
+    print(
+        f"Running distributed inference with world_size: {world_size}, world_rank: {world_rank}"
+    )
+    dist.init_process_group("nccl", rank=world_rank, world_size=world_size)
+
+    set_seed(seed)
+
+    wav = torchaudio.load("examples/audio/7143-88743-0029.flac")[0].squeeze()
+
+    # fake repeated dataset
+    dataset = TensorDataset(wav.repeat(32, 1))
+
+    sampler = DistributedSampler(dataset=dataset)
+    loader = torch.utils.data.DataLoader(
+        dataset=dataset,
+        batch_size=1,  # don't change
+        sampler=sampler,
+        num_workers=4,
+    )
+
+    spirit_lm = Spiritlm("spirit-lm-expressive-7b")
+
+    for _, data in enumerate(loader):
+        outs = spirit_lm.generate(
+            output_modality=OutputModality.ARBITRARY,
+            interleaved_inputs=[
+                GenerationInput(
+                    content=data[0],  # 0 because of batch size 1
+                    content_type=ContentType.SPEECH,
+                )
+            ],
+            generation_config=GenerationConfig(
+                temperature=0.9,
+                top_p=0.95,
+                max_new_tokens=200,
+                do_sample=True,
+            ),
+        )
+        print(f"outs: {outs}")
+
+
+def setup_env():
+    os.environ["OMP_NUM_THREADS"] = "1"
+
+
+if __name__ == "__main__":
+    setup_env()
+    run()
examples/speech_generation/spirit_model.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
examples/speech_tokenizer/spiritlm_speech_tokenizer.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
requirements.dev.txt
ADDED
@@ -0,0 +1 @@
+pytest
requirements.txt
ADDED
@@ -0,0 +1,9 @@
+omegaconf>=2.2.0
+librosa>=0.10
+local-attention>=1.9
+encodec>=0.1
+transformers
+fairscale>=0.4
+sentencepiece
+pyarrow>=14.0
+torchfcpe>=0.0.4
setup.py
ADDED
@@ -0,0 +1,60 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the FAIR Noncommercial Research License
+# found in the LICENSE file in the root directory of this source tree.
+
+import os
+from pathlib import Path
+
+from setuptools import find_packages, setup
+
+NAME = "spiritlm"
+VERSION = "0.1.0"
+DESCRIPTION = "Interleaved Spoken and Written Language Model"
+URL = "https://github.com/facebookresearch/spiritlm"
+KEYWORDS = [
+    "Language Model, Speech Language Model, Multimodal, Crossmodal, Expressivity Modeling"
+]
+LICENSE = "FAIR Noncommercial Research License"
+
+
+def _get_long_description():
+    with (Path(__file__).parent / "README.md").open(encoding="utf-8") as file:
+        long_description = file.read()
+    return long_description
+
+
+def _read_reqs(relpath):
+    fullpath = os.path.join(os.path.dirname(__file__), relpath)
+    with open(fullpath) as f:
+        return [
+            s.strip() for s in f.readlines() if (s.strip() and not s.startswith("#"))
+        ]
+
+
+setup(
+    name=NAME,
+    version=VERSION,
+    description=DESCRIPTION,
+    long_description=_get_long_description(),
+    long_description_content_type="text/plain",
+    url=URL,
+    license=LICENSE,
+    author="Meta",
+    keywords=KEYWORDS,
+    classifiers=[
+        "Intended Audience :: Science/Research",
+        "License :: FAIR Noncommercial Research License",
+        "Topic :: Multimedia :: Sound/Audio",
+        "Topic :: Scientific/Engineering :: Artificial Intelligence",
+    ],
+    packages=find_packages(),
+    zip_safe=False,
+    python_requires=">=3.9",
+    install_requires=_read_reqs("requirements.txt"),
+    extras_require={
+        "dev": ["pytest"],
+        "eval": ["pandas"],
+    },
+)
spiritlm.egg-info/PKG-INFO
ADDED
@@ -0,0 +1,84 @@
1 |
+
Metadata-Version: 2.1
|
2 |
+
Name: spiritlm
|
3 |
+
Version: 0.1.0
|
4 |
+
Summary: Interleaved Spoken and Written Language Model
|
5 |
+
Home-page: https://github.com/facebookresearch/spiritlm
|
6 |
+
Author: Meta
|
7 |
+
License: FAIR Noncommercial Research License
|
8 |
+
Keywords: Language Model, Speech Language Model, Multimodal, Crossmodal, Expressivity Modeling
|
9 |
+
Classifier: Intended Audience :: Science/Research
|
10 |
+
Classifier: License :: FAIR Noncommercial Research License
|
11 |
+
Classifier: Topic :: Multimedia :: Sound/Audio
|
12 |
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
13 |
+
Requires-Python: >=3.9
|
14 |
+
Description-Content-Type: text/plain
|
15 |
+
License-File: LICENSE
|
16 |
+
Requires-Dist: omegaconf>=2.2.0
|
17 |
+
Requires-Dist: librosa>=0.10
|
18 |
+
Requires-Dist: local-attention>=1.9
|
19 |
+
Requires-Dist: encodec>=0.1
|
20 |
+
Requires-Dist: transformers
|
21 |
+
Requires-Dist: fairscale>=0.4
|
22 |
+
Requires-Dist: sentencepiece
|
23 |
+
Requires-Dist: pyarrow>=14.0
|
24 |
+
Requires-Dist: torchfcpe>=0.0.4
|
25 |
+
Provides-Extra: dev
|
26 |
+
Requires-Dist: pytest; extra == "dev"
|
27 |
+
Provides-Extra: eval
|
28 |
+
Requires-Dist: pandas; extra == "eval"
|
29 |
+
|
30 |
+
# Meta Spirit LM: Interleaved Spoken and Written Language Model
|
31 |
+
|
32 |
+
This repository contains the model weights, inference code and evaluation scripts for the Spirit LM [paper](https://arxiv.org/pdf/2402.05755.pdf). You can find more generation samples on our [demo page](https://speechbot.github.io/spiritlm/).
|
33 |
+
|
34 |
+
## Spirit LM Model Overview
|
35 |
+
<img src="assets/spiritlm_overview.png">
|
36 |
+
|
37 |
+
## Installation Setup
|
38 |
+
### Conda
|
39 |
+
```
|
40 |
+
conda env create -f env.yml
|
41 |
+
pip install -e '.[eval]'
|
42 |
+
|
43 |
+
```
|
44 |
+
### Pip
|
45 |
+
```
|
46 |
+
pip install -e '.[eval]'
|
47 |
+
```
|
48 |
+
|
49 |
+
### Dev
|
50 |
+
(Optional: only needed if you want to run the tests.)
|
51 |
+
```
|
52 |
+
pip install -e '.[dev]'
|
53 |
+
```
|
54 |
+
|
55 |
+
## Checkpoints Setup
|
56 |
+
See [checkpoints/README.md](checkpoints/README.md)
|
57 |
+
|
58 |
+
## Quick Start
|
59 |
+
### Speech Tokenization
|
60 |
+
See [spiritlm/speech_tokenizer/README.md](spiritlm/speech_tokenizer/README.md)
|
61 |
+
### Spirit LM Generation
|
62 |
+
See [spiritlm/model/README.md](spiritlm/model/README.md)
|
63 |
+
### Speech-Text Sentiment Preservation benchmark (STSP)
|
64 |
+
See [spiritlm/eval/README.md](spiritlm/eval/README.md)
|
65 |
+
|
66 |
+
## Model Card
|
67 |
+
More details of the model can be found in [MODEL_CARD.md](MODEL_CARD.md).
|
68 |
+
|
69 |
+
## License
|
70 |
+
The present code is provided under the **FAIR Noncommercial Research License** found in [LICENSE](LICENSE).
|
71 |
+
|
72 |
+
## Citation
|
73 |
+
```
|
74 |
+
@misc{nguyen2024spiritlminterleavedspokenwritten,
|
75 |
+
title={SpiRit-LM: Interleaved Spoken and Written Language Model},
|
76 |
+
author={Tu Anh Nguyen and Benjamin Muller and Bokai Yu and Marta R. Costa-jussa and Maha Elbayad and Sravya Popuri and Paul-Ambroise Duquenne and Robin Algayres and Ruslan Mavlyutov and Itai Gat and Gabriel Synnaeve and Juan Pino and Benoit Sagot and Emmanuel Dupoux},
|
77 |
+
year={2024},
|
78 |
+
eprint={2402.05755},
|
79 |
+
archivePrefix={arXiv},
|
80 |
+
primaryClass={cs.CL},
|
81 |
+
url={https://arxiv.org/abs/2402.05755},
|
82 |
+
}
|
83 |
+
```
|
84 |
+
|
spiritlm.egg-info/SOURCES.txt
ADDED
@@ -0,0 +1,32 @@
1 |
+
LICENSE
|
2 |
+
README.md
|
3 |
+
setup.py
|
4 |
+
spiritlm/__init__.py
|
5 |
+
spiritlm.egg-info/PKG-INFO
|
6 |
+
spiritlm.egg-info/SOURCES.txt
|
7 |
+
spiritlm.egg-info/dependency_links.txt
|
8 |
+
spiritlm.egg-info/not-zip-safe
|
9 |
+
spiritlm.egg-info/requires.txt
|
10 |
+
spiritlm.egg-info/top_level.txt
|
11 |
+
spiritlm/model/__init__.py
|
12 |
+
spiritlm/model/spiritlm_model.py
|
13 |
+
spiritlm/model/utils.py
|
14 |
+
spiritlm/speech_tokenizer/__init__.py
|
15 |
+
spiritlm/speech_tokenizer/spiritlm_tokenizer.py
|
16 |
+
spiritlm/speech_tokenizer/f0/__init__.py
|
17 |
+
spiritlm/speech_tokenizer/f0/f0_extractor.py
|
18 |
+
spiritlm/speech_tokenizer/f0/f0_tokenizer.py
|
19 |
+
spiritlm/speech_tokenizer/f0/vqvae.py
|
20 |
+
spiritlm/speech_tokenizer/hifigan/__init__.py
|
21 |
+
spiritlm/speech_tokenizer/hifigan/hifigan_vocoder.py
|
22 |
+
spiritlm/speech_tokenizer/hubert/__init__.py
|
23 |
+
spiritlm/speech_tokenizer/hubert/hubert_tokenizer.py
|
24 |
+
spiritlm/speech_tokenizer/hubert/quantizer_model.py
|
25 |
+
spiritlm/speech_tokenizer/hubert/hubert_model/__init__.py
|
26 |
+
spiritlm/speech_tokenizer/hubert/hubert_model/hubert_model.py
|
27 |
+
spiritlm/speech_tokenizer/hubert/hubert_model/wav2vec2_model.py
|
28 |
+
spiritlm/speech_tokenizer/style_encoder/__init__.py
|
29 |
+
spiritlm/speech_tokenizer/style_encoder/w2v2_encoder.py
|
30 |
+
tests/__init__.py
|
31 |
+
tests/test_spirit_model.py
|
32 |
+
tests/test_tokenizer.py
|
spiritlm.egg-info/dependency_links.txt
ADDED
@@ -0,0 +1 @@
1 |
+
|
spiritlm.egg-info/not-zip-safe
ADDED
@@ -0,0 +1 @@
1 |
+
|
spiritlm.egg-info/requires.txt
ADDED
@@ -0,0 +1,15 @@
1 |
+
omegaconf>=2.2.0
|
2 |
+
librosa>=0.10
|
3 |
+
local-attention>=1.9
|
4 |
+
encodec>=0.1
|
5 |
+
transformers
|
6 |
+
fairscale>=0.4
|
7 |
+
sentencepiece
|
8 |
+
pyarrow>=14.0
|
9 |
+
torchfcpe>=0.0.4
|
10 |
+
|
11 |
+
[dev]
|
12 |
+
pytest
|
13 |
+
|
14 |
+
[eval]
|
15 |
+
pandas
|
spiritlm.egg-info/top_level.txt
ADDED
@@ -0,0 +1,2 @@
1 |
+
spiritlm
|
2 |
+
tests
|
spiritlm/__init__.py
ADDED
@@ -0,0 +1,5 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the FAIR Noncommercial Research License
|
5 |
+
# found in the LICENSE file in the root directory of this source tree.
|
spiritlm/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (162 Bytes). View file
|
|
spiritlm/eval/README.md
ADDED
@@ -0,0 +1,92 @@
1 |
+
# STSP Evaluation
|
2 |
+
The Speech-Text Sentiment Preservation (STSP) benchmark is a collection of speech and text prompts with positive, negative, or neutral sentiment.
|
3 |
+
Given a spoken or written prompt, the task is to generate a text or speech sequence of tokens that preserves the sentiment of the prompt.
|
4 |
+
|
5 |
+
The sentiment of the generated continuation is evaluated automatically with a speech or text sentiment/emotion classifier, depending on the output modality.
|
6 |
+
Based on these predictions, we derive an STSP accuracy score.
|
7 |
+
|
8 |
+
## Data Download
|
9 |
+
Download the data as well as the speech/text classifier checkpoints via this [link](https://dl.fbaipublicfiles.com/textless_nlp/spiritlm/stsp.tar.gz)
|
10 |
+
then extract the data into the folder `{spiritlm ROOT FOLDER}/data/stsp_data`
|
11 |
+
```
|
12 |
+
cd {spiritlm ROOT FOLDER}
|
13 |
+
mkdir data/stsp_data
|
14 |
+
tar -xvzf stsp.tar.gz -C data/stsp_data --strip-components=1
|
15 |
+
```
|
16 |
+
Run the following script to check that the dataset is complete and correctly laid out:
|
17 |
+
```
|
18 |
+
python spiritlm/eval/stsp/sanity_check_download.py
|
19 |
+
```
|
20 |
+
## Data Structure
|
21 |
+
The dataset contains 3 folders:
|
22 |
+
- `data`: raw audio files
|
23 |
+
- `manifest`: data splits
|
24 |
+
- `model`: speech/text classifier checkpoints
|
25 |
+
### Data
|
26 |
+
The raw audio files for:
|
27 |
+
- `emov`: EMOV
|
28 |
+
- `expresso/conversational`: EXPRESSO-ASR
|
29 |
+
- `expresso/read`: EXPRESSO-READ
|
30 |
+
|
31 |
+
### Manifest
|
32 |
+
The train/validation/test splits. Concretely, we have:
|
33 |
+
|
34 |
+
#### EMOV
|
35 |
+
- 1053 records for emov train split at `manifest/emov/emov.train.jsonl`
|
36 |
+
- 351 records for emov dev split at `manifest/emov/emov.dev.jsonl`
|
37 |
+
- 351 records for emov test split at `manifest/emov/emov.test.jsonl`
|
38 |
+
|
39 |
+
#### EXPRESSO-ASR
|
40 |
+
- 1373 records for EXPRESSO-ASR train split at `manifest/expresso/expresso_asr.train`
|
41 |
+
- 479 records for EXPRESSO-ASR dev at `manifest/expresso/expresso_asr.dev.jsonl`
|
42 |
+
- 462 records for EXPRESSO-ASR test split at `manifest/expresso/expresso_asr.test.jsonl`
|
43 |
+
|
44 |
+
#### EXPRESSO-READ
|
45 |
+
- 1024 records for EXPRESSO-READ train split at `manifest/expresso/expresso_read.train`
|
46 |
+
- 60 records for EXPRESSO-READ dev at `manifest/expresso/expresso_read.dev.jsonl`
|
47 |
+
- 54 records for EXPRESSO-READ test split at `manifest/expresso/expresso_read.test.jsonl`
|
48 |
+
|
49 |
+
#### Few-shot Samples
|
50 |
+
A subset of the EXPRESSO-ASR training set, used for the few-shot experiments (an illustrative record is shown after the list):
|
51 |
+
- `s2s.jsonl`: S -> S direction
|
52 |
+
- `s2t.jsonl`: S -> T direction
|
53 |
+
- `t2t.jsonl`: T -> T direction
|
54 |
+
- `t2s.jsonl`: T -> S direction
|
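Each line in these few-shot manifests is a JSON record with a `sentiment` label plus `prompt` and `generation` fields, which hold either raw text or a wav path relative to the data root, depending on the direction. An illustrative `t2s.jsonl` record (all field values below are made up for illustration):

```
{"sentiment": "positive", "prompt": "I am so happy to see you!", "generation": "expresso/conversational/some_happy_clip.wav"}
```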
55 |
+
|
56 |
+
### Auto-Eval Speech And Text Classifiers
|
57 |
+
|
58 |
+
The sentiment of the generated sequence is estimated in an auto-eval fashion with Speech and Text classifiers. We point to the [paper](https://arxiv.org/abs/2402.05755) for details on these classifiers.
|
59 |
+
|
60 |
+
|
61 |
+
## Prediction & Evaluation of Spirit LM on STSP (Speech/Text)
|
62 |
+
|
63 |
+
```export PYTHONPATH=.```
|
64 |
+
|
65 |
+
Set `spiritlm` to the model you want to evaluate, e.g. `spiritlm=spirit-lm-base-7b` or `spiritlm=spirit-lm-expressive-7b`
|
66 |
+
|
67 |
+
#### Speech to Text
|
68 |
+
```
torchrun --nnodes 1 --nproc-per-node 1 spiritlm/eval/stsp/predict_stsp.py --model $spiritlm --eval_manifest_path data/stsp_data/manifest/emov/emov.test.jsonl --eval --write_pred ./pred_s_t.jsonl --input_output speech_text
```
|
69 |
+
#### Text to Text
|
70 |
+
```
torchrun --nnodes 1 --nproc-per-node 1 spiritlm/eval/stsp/predict_stsp.py --model $spiritlm --eval_manifest_path data/stsp_data/manifest/emov/emov.test.jsonl --eval --write_pred ./pred_t_t.jsonl --input_output text_text
```
|
71 |
+
#### Text to Speech
|
72 |
+
```
torchrun --nnodes 1 --nproc-per-node 1 spiritlm/eval/stsp/predict_stsp.py --model $spiritlm --eval_manifest_path data/stsp_data/manifest/emov/emov.test.jsonl --eval --write_pred ./pred_t_s.jsonl --input_output text_speech
```
|
73 |
+
#### Speech to Speech
|
74 |
+
```
torchrun --nnodes 1 --nproc-per-node 1 spiritlm/eval/stsp/predict_stsp.py --model $spiritlm --eval_manifest_path data/stsp_data/manifest/emov/emov.test.jsonl --eval --write_pred ./pred_s_s.jsonl --input_output speech_speech
```
|
75 |
+
|
76 |
+
|
77 |
+
### Post-hoc Evaluation
|
78 |
+
|
79 |
+
To evaluate the performance of a model other than Spirit LM, you can use the following evaluation script, which takes a predictions `.jsonl` file as input.
|
80 |
+
|
81 |
+
```
|
82 |
+
python spiritlm/eval/eval_stsp.py --ref_file $REF_FILE --pred_file $pred_file
|
83 |
+
```
|
84 |
+
|
85 |
+
e.g.
|
86 |
+
|
87 |
+
```
|
88 |
+
python spiritlm/eval/eval_stsp.py \
|
89 |
+
--ref_file ./data/examples/ref.jsonl \
|
90 |
+
--pred_file ./data/examples/pred.jsonl
|
91 |
+
> Accuracy: 100.00% for predictions ./data/examples/pred.jsonl
|
92 |
+
```
|
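For reference, the prediction and reference files are matched by `id`: each prediction line carries an `id` and a `pred` field, and each reference line carries the same `id` with a gold `sentiment` (or `emotion`) label. A matching pair could look like this (values are illustrative):

```
# pred_file line
{"id": "5", "pred": "positive"}
# ref_file line
{"id": "5", "sentiment": "positive"}
```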
spiritlm/eval/eval_stsp.py
ADDED
@@ -0,0 +1,87 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the FAIR Noncommercial Research License
|
5 |
+
# found in the LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import argparse
|
8 |
+
import json
|
9 |
+
from typing import Dict, Union
|
10 |
+
|
11 |
+
import pandas as pd
|
12 |
+
from spiritlm.eval.stsp.utils import EMOTION_2_SENTIMENT
|
13 |
+
|
14 |
+
|
15 |
+
def load_pred(predictions):
|
16 |
+
ret = {}
|
17 |
+
with open(predictions) as f:
|
18 |
+
for line in f:
|
19 |
+
pred = json.loads(line)
|
20 |
+
ret[str(pred["id"])] = pred["pred"]
|
21 |
+
|
22 |
+
assert sum(1 for _ in open(predictions)) == len(ret)
|
23 |
+
|
24 |
+
return ret
|
25 |
+
|
26 |
+
|
27 |
+
def eval(
|
28 |
+
gold_records: str, predictions: Union[str, Dict], info_data="", label="sentiment"
|
29 |
+
):
|
30 |
+
n_gold_records = sum(1 for _ in open(gold_records))
|
31 |
+
n_lines_pred = (
|
32 |
+
sum(1 for _ in open(predictions))
|
33 |
+
if isinstance(predictions, str)
|
34 |
+
else len(predictions)
|
35 |
+
)
|
36 |
+
assert (
|
37 |
+
n_gold_records == n_lines_pred
|
38 |
+
), f"Mismatch between prediction ({n_lines_pred} samples in {predictions}) and reference ({n_gold_records} in {gold_records})"
|
39 |
+
|
40 |
+
pred_dic = load_pred(predictions) if isinstance(predictions, str) else predictions
|
41 |
+
scores = []
|
42 |
+
|
43 |
+
with open(gold_records) as gold:
|
44 |
+
for line in gold:
|
45 |
+
ref = json.loads(line)
|
46 |
+
try:
|
47 |
+
if label in ref:
|
48 |
+
scores.append(pred_dic[str(ref["id"])] == ref[label])
|
49 |
+
else:
|
50 |
+
assert label == "sentiment" and "emotion" in ref, ref
|
51 |
+
sentiment = EMOTION_2_SENTIMENT[ref["emotion"]]
|
52 |
+
scores.append(pred_dic[str(ref["id"])] == sentiment)
|
53 |
+
except Exception as e:
|
54 |
+
print(
|
55 |
+
f"ERROR in matching the predicted labels with the gold ones: {e}: ref['id'] do not match any key in {pred_dic}', {ref['id']}: "
|
56 |
+
)
|
57 |
+
# TODO: add other metrics if needed : F1 per class, etc.
|
58 |
+
report = pd.DataFrame({"Correct": scores})
|
59 |
+
if isinstance(predictions, str):
|
60 |
+
info_data += f"from {predictions}"
|
61 |
+
print(
|
62 |
+
f"Accuracy: {(report['Correct']==1).sum()/len(report)*100:0.2f}% for predictions {info_data}"
|
63 |
+
)
|
64 |
+
|
65 |
+
|
66 |
+
if __name__ == "__main__":
|
67 |
+
parser = argparse.ArgumentParser()
|
68 |
+
|
69 |
+
parser.add_argument(
|
70 |
+
"--ref_file",
|
71 |
+
type=str,
|
72 |
+
help="Path to reference record",
|
73 |
+
)
|
74 |
+
parser.add_argument(
|
75 |
+
"--pred_file",
|
76 |
+
type=str,
|
77 |
+
help="Path to prediction: should be jsonl with each entry {'pred': , 'id': }",
|
78 |
+
)
|
79 |
+
parser.add_argument(
|
80 |
+
"--label",
|
81 |
+
type=str,
|
82 |
+
default="sentiment",
|
83 |
+
help="sentiment or emotion",
|
84 |
+
)
|
85 |
+
args = parser.parse_args()
|
86 |
+
|
87 |
+
eval(args.ref_file, args.pred_file, label=args.label)
|
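Besides the CLI, `eval` also accepts an in-memory dict of predictions keyed by record id (this is how `predict_stsp.py` calls it). A minimal sketch, with an illustrative manifest path; note there must be exactly one prediction per reference line:

```python
from spiritlm.eval.eval_stsp import eval

# Predicted sentiment label for each record id (illustrative values).
predictions = {"0": "positive", "1": "negative", "2": "neutral"}
eval("path/to/ref.jsonl", predictions, label="sentiment")
# Prints the STSP accuracy over the matched records.
```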
spiritlm/eval/load_data.py
ADDED
@@ -0,0 +1,50 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the FAIR Noncommercial Research License
|
5 |
+
# found in the LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import json
|
8 |
+
from pathlib import Path
|
9 |
+
|
10 |
+
import torch
|
11 |
+
import torchaudio
|
12 |
+
|
13 |
+
|
14 |
+
class SpeechData(torch.utils.data.Dataset):
|
15 |
+
def __init__(self, manifest_dir, root_dir=None):
|
16 |
+
if root_dir is None:
|
17 |
+
root_dir = "."
|
18 |
+
self.root_dir = Path(root_dir)
|
19 |
+
self.manifest_dir = self.root_dir / manifest_dir
|
20 |
+
self.wav_field = "wav_path"
|
21 |
+
self.manifest = [json.loads(line.strip()) for line in open(manifest_dir)]
|
22 |
+
|
23 |
+
def __getitem__(self, idx):
|
24 |
+
wav_path = self.root_dir / self.manifest[idx][self.wav_field]
|
25 |
+
return {
|
26 |
+
"wav": torchaudio.load(wav_path)[0].squeeze(0),
|
27 |
+
"id": str(self.manifest[idx]["id"]),
|
28 |
+
}
|
29 |
+
|
30 |
+
def __len__(self):
|
31 |
+
return len(self.manifest)
|
32 |
+
|
33 |
+
|
34 |
+
class TextData(torch.utils.data.Dataset):
|
35 |
+
def __init__(self, manifest_dir, root_dir=None):
|
36 |
+
if root_dir is None:
|
37 |
+
root_dir = "."
|
38 |
+
self.root_dir = Path(root_dir)
|
39 |
+
self.manifest_dir = self.root_dir / manifest_dir
|
40 |
+
self.text_field = "asr"
|
41 |
+
self.manifest = [json.loads(line.strip()) for line in open(manifest_dir)]
|
42 |
+
|
43 |
+
def __getitem__(self, idx):
|
44 |
+
return {
|
45 |
+
"text": self.manifest[idx][self.text_field],
|
46 |
+
"id": str(self.manifest[idx]["id"]),
|
47 |
+
}
|
48 |
+
|
49 |
+
def __len__(self):
|
50 |
+
return len(self.manifest)
|
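A minimal usage sketch of these datasets (the manifest and root paths are illustrative; they mirror the defaults used by `predict_stsp.py`):

```python
from torch.utils.data import DataLoader
from spiritlm.eval.load_data import SpeechData

dataset = SpeechData(
    "data/stsp_data/manifest/emov/emov.test.jsonl",  # manifest with wav_path/id fields
    root_dir="data/stsp_data/data",                  # root folder of the raw audio
)
loader = DataLoader(dataset, batch_size=1, num_workers=4)
for batch in loader:
    print(batch["id"][0], batch["wav"].shape)  # wav shape: (1, num_samples)
    break
```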
spiritlm/eval/stsp/few_shot_prompt.py
ADDED
@@ -0,0 +1,101 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the FAIR Noncommercial Research License
|
5 |
+
# found in the LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import math
|
8 |
+
from typing import Union
|
9 |
+
|
10 |
+
import pandas as pd
|
11 |
+
import torch
|
12 |
+
import torchaudio
|
13 |
+
from spiritlm.eval.stsp.stsp_constants import STSP_DATA_ROOT, STSP_MANIFEST_ROOT
|
14 |
+
from spiritlm.model.spiritlm_model import Spiritlm
|
15 |
+
|
16 |
+
FEW_SHOT_MANIFEST_DIR = STSP_MANIFEST_ROOT / "few_shot"
|
17 |
+
FEW_SHOT_TEMPLATE = "{prompt}{generation}"
|
18 |
+
|
19 |
+
|
20 |
+
def wav_prompt(spiritlm_model: Spiritlm, wav: Union[str, torch.Tensor]) -> str:
|
21 |
+
return spiritlm_model.SPEECH_PROMPT_PREFIX + spiritlm_model.speech_tokenizer(wav)
|
22 |
+
|
23 |
+
|
24 |
+
def text_prompt(spiritlm_model: Spiritlm, text: str) -> str:
|
25 |
+
return spiritlm_model.TEXT_PROMPT_PREFIX + text
|
26 |
+
|
27 |
+
|
28 |
+
def _load_half_wav(wav_path: str, load_first_half: bool) -> torch.Tensor:
|
29 |
+
wav_path = STSP_DATA_ROOT / wav_path
|
30 |
+
wav = torchaudio.load(wav_path)[0].squeeze(0)
|
31 |
+
size = wav.size()[0]
|
32 |
+
half_size = size // 2
|
33 |
+
if load_first_half:
|
34 |
+
wav = wav[:half_size]
|
35 |
+
else:
|
36 |
+
wav = wav[half_size:]
|
37 |
+
return wav
|
38 |
+
|
39 |
+
|
40 |
+
def build_few_shot_prompt(
|
41 |
+
spiritlm_model: Spiritlm,
|
42 |
+
input_output: str,
|
43 |
+
n_shots: int = 3,
|
44 |
+
) -> str:
|
45 |
+
"""
|
46 |
+
Build the few-shot prompt by simply concatenating a set of examples.
|
47 |
+
|
48 |
+
E.g., a 3-shot T->S prompt would look like this:
|
49 |
+
"[Text]text1[Speech]speech_tokens1\n[Text]text2[Speech]speech_tokens2\n[Text]text3[Speech]speech_tokens3\n"
|
50 |
+
"""
|
51 |
+
manifest_file_mapping = {
|
52 |
+
"text_text": "t2t",
|
53 |
+
"speech_text": "s2t",
|
54 |
+
"text_speech": "t2s",
|
55 |
+
"speech_speech": "s2s",
|
56 |
+
}
|
57 |
+
manifest_path = (
|
58 |
+
FEW_SHOT_MANIFEST_DIR / f"{manifest_file_mapping[input_output]}.jsonl"
|
59 |
+
)
|
60 |
+
df = pd.read_json(manifest_path, lines=True)
|
61 |
+
assert n_shots <= len(df)
|
62 |
+
|
63 |
+
# ensure a balanced number of samples for each sentiment
|
64 |
+
nb_samples_per_sentiment = math.ceil(n_shots / 3)
|
65 |
+
df = df.groupby("sentiment").sample(n=nb_samples_per_sentiment)
|
66 |
+
|
67 |
+
prompts = []
|
68 |
+
for _, row in df.iterrows():
|
69 |
+
prompt = row["prompt"]
|
70 |
+
generation = row["generation"]
|
71 |
+
if input_output == "text_text":
|
72 |
+
prompt = FEW_SHOT_TEMPLATE.format(
|
73 |
+
prompt=text_prompt(spiritlm_model, prompt),
|
74 |
+
generation=text_prompt(spiritlm_model, generation),
|
75 |
+
)
|
76 |
+
elif input_output == "text_speech":
|
77 |
+
prompt = FEW_SHOT_TEMPLATE.format(
|
78 |
+
prompt=text_prompt(spiritlm_model, prompt),
|
79 |
+
generation=wav_prompt(
|
80 |
+
spiritlm_model, _load_half_wav(generation, load_first_half=False)
|
81 |
+
),
|
82 |
+
)
|
83 |
+
elif input_output == "speech_text":
|
84 |
+
prompt = FEW_SHOT_TEMPLATE.format(
|
85 |
+
prompt=wav_prompt(
|
86 |
+
spiritlm_model, _load_half_wav(prompt, load_first_half=True)
|
87 |
+
),
|
88 |
+
generation=text_prompt(spiritlm_model, generation),
|
89 |
+
)
|
90 |
+
elif input_output == "speech_speech":
|
91 |
+
prompt = FEW_SHOT_TEMPLATE.format(
|
92 |
+
prompt=wav_prompt(
|
93 |
+
spiritlm_model, _load_half_wav(prompt, load_first_half=True)
|
94 |
+
),
|
95 |
+
generation=wav_prompt(
|
96 |
+
spiritlm_model, _load_half_wav(generation, load_first_half=False)
|
97 |
+
),
|
98 |
+
)
|
99 |
+
prompts.append(prompt)
|
100 |
+
print(f"prompts: {prompts}")
|
101 |
+
return "\n".join(prompts) + "\n"
|
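A sketch of how this is used upstream (mirroring `predict_stsp.py`): the returned string is passed as the `prompt` argument of `Spiritlm.generate`.

```python
from spiritlm.eval.stsp.few_shot_prompt import build_few_shot_prompt
from spiritlm.model.spiritlm_model import Spiritlm

spirit_lm = Spiritlm("spirit-lm-expressive-7b")
# 3-shot prompt for the speech -> text direction.
prompt = build_few_shot_prompt(spirit_lm, input_output="speech_text", n_shots=3)
```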
spiritlm/eval/stsp/predict_stsp.py
ADDED
@@ -0,0 +1,299 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the FAIR Noncommercial Research License
|
5 |
+
# found in the LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
"""
|
8 |
+
Usage example:
|
9 |
+
|
10 |
+
cd {SPIRITLM ROOT FOLDER}
|
11 |
+
export PYTHONPATH=.
|
12 |
+
|
13 |
+
# Speech to Text
|
14 |
+
torchrun --nnodes 1 --nproc-per-node 1 spiritlm/eval/stsp/predict_stsp.py --eval_manifest_path data/examples/ref.jsonl --eval --write_pred ./pred_s_t.jsonl --input_output speech_text
|
15 |
+
# Text to Text
|
16 |
+
torchrun --nnodes 1 --nproc-per-node 1 spiritlm/eval/stsp/predict_stsp.py --eval_manifest_path data/examples/ref.jsonl --eval --write_pred ./pred_t_t.jsonl --input_output text_text
|
17 |
+
# Text to Speech
|
18 |
+
torchrun --nnodes 1 --nproc-per-node 1 spiritlm/eval/stsp/predict_stsp.py --eval_manifest_path data/examples/ref.jsonl --eval --write_pred ./pred_t_s.jsonl --input_output text_speech
|
19 |
+
# Speech to Speech
|
20 |
+
torchrun --nnodes 1 --nproc-per-node 1 spiritlm/eval/stsp/predict_stsp.py --eval_manifest_path data/examples/ref.jsonl --eval --write_pred ./pred_s_s.jsonl --input_output speech_speech
|
21 |
+
|
22 |
+
"""
|
23 |
+
|
24 |
+
import argparse
|
25 |
+
import json
|
26 |
+
import os
|
27 |
+
import uuid
|
28 |
+
from pathlib import Path
|
29 |
+
from typing import Union
|
30 |
+
|
31 |
+
import torch
|
32 |
+
import torch.distributed as dist
|
33 |
+
import torchaudio
|
34 |
+
from spiritlm.eval.eval_stsp import eval
|
35 |
+
from spiritlm.eval.load_data import SpeechData, TextData
|
36 |
+
from spiritlm.eval.stsp.few_shot_prompt import build_few_shot_prompt
|
37 |
+
from spiritlm.eval.stsp.sentiment_classifiers import (
|
38 |
+
get_text_sentiment_prediction,
|
39 |
+
load_sentiment_classifier,
|
40 |
+
)
|
41 |
+
from spiritlm.eval.stsp.stsp_constants import STSP_DATA_ROOT, STSP_MODEL_ROOT
|
42 |
+
from spiritlm.eval.stsp.utils import (
|
43 |
+
ExpressoEmotionClassifier,
|
44 |
+
load_emotion_classifier,
|
45 |
+
wav2emotion_and_sentiment,
|
46 |
+
)
|
47 |
+
from spiritlm.model.spiritlm_model import (
|
48 |
+
ContentType,
|
49 |
+
GenerationInput,
|
50 |
+
InterleavedOutputs,
|
51 |
+
OutputModality,
|
52 |
+
Spiritlm,
|
53 |
+
)
|
54 |
+
from torch.utils.data.distributed import DistributedSampler
|
55 |
+
from tqdm import tqdm
|
56 |
+
from transformers import AutoModelForSequenceClassification, GenerationConfig, set_seed
|
57 |
+
|
58 |
+
SPEECH_CLASSIFIER = STSP_MODEL_ROOT / "speech_classifier"
|
59 |
+
TEXT_CLASSIFIER = STSP_MODEL_ROOT / "text_classifier"
|
60 |
+
|
61 |
+
NB_RETRIES = 3
|
62 |
+
|
63 |
+
|
64 |
+
def get_eval_classifier(args):
|
65 |
+
if args.input_output.endswith("speech"):
|
66 |
+
return load_emotion_classifier(str(SPEECH_CLASSIFIER))
|
67 |
+
elif args.input_output.endswith("text"):
|
68 |
+
return load_sentiment_classifier(str(TEXT_CLASSIFIER))
|
69 |
+
else:
|
70 |
+
raise (Exception(f"{args.input_output} not supported"))
|
71 |
+
|
72 |
+
|
73 |
+
def get_sentiment(
|
74 |
+
input_output,
|
75 |
+
generation,
|
76 |
+
classifier: Union[AutoModelForSequenceClassification, ExpressoEmotionClassifier],
|
77 |
+
):
|
78 |
+
if input_output.endswith("speech"):
|
79 |
+
_, pred_sentiment = wav2emotion_and_sentiment(generation, classifier)
|
80 |
+
elif input_output.endswith("text"):
|
81 |
+
_, pred_sentiment = get_text_sentiment_prediction(generation, classifier)
|
82 |
+
return pred_sentiment
|
83 |
+
|
84 |
+
|
85 |
+
def write_jsonl(dir: str, predictions: dict):
|
86 |
+
Path(dir).parent.mkdir(exist_ok=True, parents=True)
|
87 |
+
with open(dir, "w") as f:
|
88 |
+
for id, result_dict in predictions.items():
|
89 |
+
record = {"id": id, **result_dict}
|
90 |
+
json_string = json.dumps(record)
|
91 |
+
f.write(json_string + "\n") # Add a newline to separate JSON objects
|
92 |
+
print(f"{dir} written")
|
93 |
+
|
94 |
+
|
95 |
+
def write_wav(
|
96 |
+
wav,
|
97 |
+
save_dir: Path,
|
98 |
+
sample_rate: int = 16_000,
|
99 |
+
) -> str:
|
100 |
+
"""Save wav under `save_dir` with a random name and return the full path."""
|
101 |
+
save_dir.mkdir(exist_ok=True, parents=True)
|
102 |
+
random_path = save_dir / (str(uuid.uuid4()) + ".wav")
|
103 |
+
torchaudio.save(
|
104 |
+
random_path, torch.from_numpy(wav).unsqueeze(0), sample_rate=sample_rate
|
105 |
+
)
|
106 |
+
return str(random_path)
|
107 |
+
|
108 |
+
|
109 |
+
def run(args):
|
110 |
+
world_size = int(os.environ["WORLD_SIZE"])
|
111 |
+
world_rank = int(os.environ["RANK"])
|
112 |
+
print(
|
113 |
+
f"Running distributed inference with world_size: {world_size}, world_rank: {world_rank}"
|
114 |
+
)
|
115 |
+
dist.init_process_group("nccl", rank=world_rank, world_size=world_size)
|
116 |
+
set_seed(args.seed)
|
117 |
+
spiritlm_model = Spiritlm(args.model)
|
118 |
+
evaluation_classifier = get_eval_classifier(args)
|
119 |
+
input_output = args.input_output
|
120 |
+
eval_manifest_path = args.eval_manifest_path
|
121 |
+
write_wav_output = args.write_wav_output
|
122 |
+
|
123 |
+
if args.few_shot > 0:
|
124 |
+
prompt = build_few_shot_prompt(
|
125 |
+
spiritlm_model=spiritlm_model,
|
126 |
+
input_output=args.input_output,
|
127 |
+
n_shots=args.few_shot,
|
128 |
+
)
|
129 |
+
else:
|
130 |
+
prompt = None
|
131 |
+
|
132 |
+
# load
|
133 |
+
if input_output.startswith("speech"):
|
134 |
+
eval_dataset = SpeechData(eval_manifest_path, root_dir=STSP_DATA_ROOT)
|
135 |
+
elif input_output.startswith("text"):
|
136 |
+
eval_dataset = TextData(eval_manifest_path, root_dir=STSP_DATA_ROOT)
|
137 |
+
|
138 |
+
sampler = DistributedSampler(dataset=eval_dataset)
|
139 |
+
loader = torch.utils.data.DataLoader(
|
140 |
+
dataset=eval_dataset,
|
141 |
+
batch_size=1, # large batch size is not supported yet
|
142 |
+
sampler=sampler,
|
143 |
+
num_workers=4,
|
144 |
+
)
|
145 |
+
predictions = {}
|
146 |
+
if input_output.endswith("speech"):
|
147 |
+
output_modality = OutputModality.SPEECH
|
148 |
+
max_new_tokens = 300
|
149 |
+
else:
|
150 |
+
output_modality = OutputModality.TEXT
|
151 |
+
max_new_tokens = 50
|
152 |
+
for _, data in tqdm(
|
153 |
+
enumerate(loader),
|
154 |
+
desc=f"Predict {eval_manifest_path}",
|
155 |
+
total=len(eval_dataset) // world_size,
|
156 |
+
):
|
157 |
+
# retry the generation multiple times because sometimes it does not generate hubert tokens
|
158 |
+
for i in range(NB_RETRIES):
|
159 |
+
try:
|
160 |
+
out: InterleavedOutputs = spiritlm_model.generate(
|
161 |
+
output_modality=output_modality,
|
162 |
+
interleaved_inputs=[
|
163 |
+
GenerationInput(
|
164 |
+
content=(
|
165 |
+
data["wav"][0]
|
166 |
+
if input_output.startswith("speech")
|
167 |
+
else data["text"][0]
|
168 |
+
), # 0 because of batch size 1
|
169 |
+
content_type=(
|
170 |
+
ContentType.SPEECH
|
171 |
+
if input_output.startswith("speech")
|
172 |
+
else ContentType.TEXT
|
173 |
+
),
|
174 |
+
)
|
175 |
+
],
|
176 |
+
generation_config=GenerationConfig(
|
177 |
+
temperature=0.8,
|
178 |
+
top_p=0.95,
|
179 |
+
max_new_tokens=max_new_tokens,
|
180 |
+
do_sample=True,
|
181 |
+
),
|
182 |
+
prompt=prompt,
|
183 |
+
)
|
184 |
+
except Exception as e:
|
185 |
+
print(f"Got an exception when generating: {e}")
|
186 |
+
if i == NB_RETRIES - 1:
|
187 |
+
raise Exception(f"Failed to generate after {NB_RETRIES} retries")
|
188 |
+
else:
|
189 |
+
break
|
190 |
+
assert len(out) == 1
|
191 |
+
generated_output = out[0].content
|
192 |
+
detected_sentiment = get_sentiment(
|
193 |
+
input_output, generated_output, evaluation_classifier
|
194 |
+
)
|
195 |
+
if output_modality == OutputModality.TEXT:
|
196 |
+
generation = generated_output
|
197 |
+
elif write_wav_output and output_modality == OutputModality.SPEECH:
|
198 |
+
generation = write_wav(generated_output, Path(write_wav_output))
|
199 |
+
else:
|
200 |
+
generation = None
|
201 |
+
result_dict = {"pred": detected_sentiment}
|
202 |
+
if generation is not None:
|
203 |
+
result_dict["generation"] = generation
|
204 |
+
predictions[str(data["id"][0])] = result_dict
|
205 |
+
|
206 |
+
if args.eval:
|
207 |
+
gathered_predictions = [None for _ in range(world_size)]
|
208 |
+
dist.gather_object(
|
209 |
+
predictions, gathered_predictions if world_rank == 0 else None, dst=0
|
210 |
+
)
|
211 |
+
if world_rank == 0:
|
212 |
+
all_predictions = {k: v for d in gathered_predictions for k, v in d.items()}
|
213 |
+
eval(
|
214 |
+
eval_manifest_path,
|
215 |
+
{k: v["pred"] for k, v in all_predictions.items()},
|
216 |
+
info_data=f"{eval_manifest_path}, input-output {input_output}",
|
217 |
+
label="sentiment",
|
218 |
+
)
|
219 |
+
|
220 |
+
if args.write_pred is not None and world_rank == 0:
|
221 |
+
write_jsonl(args.write_pred, all_predictions)
|
222 |
+
|
223 |
+
|
224 |
+
def setup_env():
|
225 |
+
os.environ["OMP_NUM_THREADS"] = "1"
|
226 |
+
|
227 |
+
|
228 |
+
if __name__ == "__main__":
|
229 |
+
parser = argparse.ArgumentParser()
|
230 |
+
parser.add_argument(
|
231 |
+
"--eval_manifest_path", # data/examples/ref.jsonl
|
232 |
+
type=str,
|
233 |
+
help="Path to reference record",
|
234 |
+
required=True,
|
235 |
+
)
|
236 |
+
|
237 |
+
parser.add_argument(
|
238 |
+
"--data_root_dir", # data/stsp_data
|
239 |
+
type=str,
|
240 |
+
help=f"Path to root data folder, default to {str(STSP_DATA_ROOT)}",
|
241 |
+
default=str(STSP_DATA_ROOT),
|
242 |
+
required=False,
|
243 |
+
)
|
244 |
+
|
245 |
+
parser.add_argument(
|
246 |
+
"--model",
|
247 |
+
type=str,
|
248 |
+
default="spirit-lm-expressive-7b",
|
249 |
+
help="Model name (spirit-lm-base-7b or spirit-lm-expressive-7b) or path to model",
|
250 |
+
required=False,
|
251 |
+
)
|
252 |
+
parser.add_argument(
|
253 |
+
"--few_shot",
|
254 |
+
type=int,
|
255 |
+
default=0,
|
256 |
+
help="Number of few shot examples, 3/6/9",
|
257 |
+
required=False,
|
258 |
+
)
|
259 |
+
parser.add_argument(
|
260 |
+
"--input_output",
|
261 |
+
type=str,
|
262 |
+
default="speech_speech",
|
263 |
+
help="speech_speech speech_text text_speech text_text",
|
264 |
+
required=False,
|
265 |
+
)
|
266 |
+
parser.add_argument(
|
267 |
+
"--eval_type",
|
268 |
+
type=str,
|
269 |
+
default="emotion",
|
270 |
+
required=False,
|
271 |
+
)
|
272 |
+
parser.add_argument(
|
273 |
+
"--write_pred",
|
274 |
+
type=str,
|
275 |
+
default=None,
|
276 |
+
help="Path to save the predictions output",
|
277 |
+
required=False,
|
278 |
+
)
|
279 |
+
parser.add_argument(
|
280 |
+
"--write_wav_output",
|
281 |
+
type=str,
|
282 |
+
default=None,
|
283 |
+
help="Path to save the generated audio if the output is speech",
|
284 |
+
required=False,
|
285 |
+
)
|
286 |
+
parser.add_argument(
|
287 |
+
"--eval",
|
288 |
+
default=False,
|
289 |
+
action="store_true",
|
290 |
+
)
|
291 |
+
parser.add_argument(
|
292 |
+
"--seed",
|
293 |
+
default=0,
|
294 |
+
type=int,
|
295 |
+
)
|
296 |
+
|
297 |
+
args = parser.parse_args()
|
298 |
+
setup_env()
|
299 |
+
run(args)
|
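When `--write_pred` is set, each output line holds the record `id`, the predicted sentiment, and, when available (text output, or speech output with `--write_wav_output`), the generation itself. An illustrative prediction line:

```
{"id": "7", "pred": "positive", "generation": "I am really glad you made it."}
```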
spiritlm/eval/stsp/sanity_check_download.py
ADDED
@@ -0,0 +1,30 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the FAIR Noncommercial Research License
|
5 |
+
# found in the LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import json
|
8 |
+
|
9 |
+
from spiritlm.eval.stsp.stsp_constants import STSP_DATA_ROOT, STSP_MANIFEST_ROOT
|
10 |
+
|
11 |
+
|
12 |
+
def check_all_datasets():
|
13 |
+
for dataset_manifest in STSP_MANIFEST_ROOT.glob("**/*jsonl"):
|
14 |
+
records_checked = 0
|
15 |
+
print(f"dataset_manifset: {dataset_manifset}")
|
16 |
+
with dataset_manifest.open() as f:
|
17 |
+
for record in f:
|
18 |
+
record = json.loads(record)
|
19 |
+
for wav_key in ["wav_path", "prompt", "generation"]:
|
20 |
+
if wav_key in record and record[wav_key].endswith(".wav"):
|
21 |
+
wav_path = STSP_DATA_ROOT / record[wav_key]
|
22 |
+
assert (
|
23 |
+
wav_path.is_file()
|
24 |
+
), f"Record {record[wav_key]} not found in {str(wav_path)} and listed in {dataset_manifset}"
|
25 |
+
records_checked += 1
|
26 |
+
print(f"{records_checked} records checked for {dataset_manifset.stem} split")
|
27 |
+
|
28 |
+
|
29 |
+
if __name__ == "__main__":
|
30 |
+
check_all_datasets()
|
spiritlm/eval/stsp/sentiment_classifiers.py
ADDED
@@ -0,0 +1,37 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the FAIR Noncommercial Research License
|
5 |
+
# found in the LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
from typing import Any, Dict, List, Tuple
|
8 |
+
|
9 |
+
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
|
10 |
+
|
11 |
+
|
12 |
+
def pred_to_label(
|
13 |
+
sentiment_prediction_scores: List[List[Dict[str, Any]]],
|
14 |
+
) -> Tuple[float, str]:
|
15 |
+
if isinstance(sentiment_prediction_scores[0], list):
|
16 |
+
sentiment_prediction_scores = sentiment_prediction_scores[0]
|
17 |
+
item_with_max_score = max(
|
18 |
+
sentiment_prediction_scores, key=lambda _dict: _dict["score"]
|
19 |
+
)
|
20 |
+
score = item_with_max_score["score"]
|
21 |
+
return score, item_with_max_score["label"].lower()
|
22 |
+
|
23 |
+
|
24 |
+
def get_text_sentiment_prediction(text: str, sentiment_classifier) -> Tuple[float, str]:
|
25 |
+
return pred_to_label(sentiment_classifier(text))
|
26 |
+
|
27 |
+
|
28 |
+
def load_sentiment_classifier(model_dir: str):
|
29 |
+
classifier = pipeline(
|
30 |
+
task="text-classification",
|
31 |
+
model=AutoModelForSequenceClassification.from_pretrained(model_dir),
|
32 |
+
tokenizer=AutoTokenizer.from_pretrained(
|
33 |
+
"j-hartmann/sentiment-roberta-large-english-3-classes"
|
34 |
+
),
|
35 |
+
top_k=None,
|
36 |
+
)
|
37 |
+
return classifier
|
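A minimal sketch of scoring a text generation with this classifier (the checkpoint directory is the one unpacked under `data/stsp_data/model`, as used in `predict_stsp.py`):

```python
from spiritlm.eval.stsp.sentiment_classifiers import (
    get_text_sentiment_prediction,
    load_sentiment_classifier,
)

classifier = load_sentiment_classifier("data/stsp_data/model/text_classifier")
# Returns the best score and its lower-cased label.
score, label = get_text_sentiment_prediction("I am so happy today!", classifier)
print(label, score)
```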
spiritlm/eval/stsp/stsp_constants.py
ADDED
@@ -0,0 +1,12 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the FAIR Noncommercial Research License
|
5 |
+
# found in the LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
from pathlib import Path
|
8 |
+
|
9 |
+
STSP_ROOT = Path(__file__).parents[3] / "data" / "stsp_data"
|
10 |
+
STSP_DATA_ROOT = STSP_ROOT / "data"
|
11 |
+
STSP_MODEL_ROOT = STSP_ROOT / "model"
|
12 |
+
STSP_MANIFEST_ROOT = STSP_ROOT / "manifest"
|
spiritlm/eval/stsp/utils.py
ADDED
@@ -0,0 +1,122 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the FAIR Noncommercial Research License
|
5 |
+
# found in the LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
from dataclasses import dataclass
|
8 |
+
from functools import cache
|
9 |
+
from typing import List, Optional, Tuple
|
10 |
+
|
11 |
+
import torch
|
12 |
+
import torchaudio
|
13 |
+
from transformers import AutoFeatureExtractor, AutoModelForAudioClassification
|
14 |
+
|
15 |
+
EXPRESSO_EMOTION_2_SENTIMENT = {
|
16 |
+
"happy": "positive",
|
17 |
+
"angry": "negative",
|
18 |
+
"sad": "negative",
|
19 |
+
"default": "neutral",
|
20 |
+
}
|
21 |
+
|
22 |
+
EMOTION_2_SENTIMENT = {
|
23 |
+
"happy": "positive",
|
24 |
+
"angry": "negative",
|
25 |
+
"sad": "negative",
|
26 |
+
"default": "neutral",
|
27 |
+
"neutral": "neutral",
|
28 |
+
"amused": "positive",
|
29 |
+
}
|
30 |
+
|
31 |
+
|
32 |
+
@cache
|
33 |
+
def emotions2new_label_names_and_indices(
|
34 |
+
emotions_to_select: Tuple[str],
|
35 |
+
label_names: Tuple[str],
|
36 |
+
) -> Tuple[List[int], List[str]]:
|
37 |
+
emotion2index = {e: i for i, e in enumerate(label_names)}
|
38 |
+
sorted_indices_emotions = sorted(
|
39 |
+
[(emotion2index[emotion], emotion) for emotion in emotions_to_select]
|
40 |
+
)
|
41 |
+
zipped = list(zip(*sorted_indices_emotions))
|
42 |
+
return zipped
|
43 |
+
|
44 |
+
|
45 |
+
def expresso_emotion2_sentiment(emotion: str):
|
46 |
+
return EXPRESSO_EMOTION_2_SENTIMENT[emotion]
|
47 |
+
|
48 |
+
|
49 |
+
@dataclass
|
50 |
+
class ExpressoEmotionClassifier:
|
51 |
+
feature_extractor: AutoFeatureExtractor
|
52 |
+
model: AutoModelForAudioClassification
|
53 |
+
label_names: List[str]
|
54 |
+
|
55 |
+
|
56 |
+
def load_emotion_classifier(checkpoint_path: str) -> ExpressoEmotionClassifier:
|
57 |
+
feature_extractor = AutoFeatureExtractor.from_pretrained(checkpoint_path)
|
58 |
+
model = (
|
59 |
+
AutoModelForAudioClassification.from_pretrained(checkpoint_path).cuda().eval()
|
60 |
+
)
|
61 |
+
label_names = [model.config.id2label[i] for i in range(model.config.num_labels)]
|
62 |
+
print(f"Classification model loaded from {checkpoint_path} !")
|
63 |
+
return ExpressoEmotionClassifier(feature_extractor, model, label_names)
|
64 |
+
|
65 |
+
|
66 |
+
@torch.inference_mode()
|
67 |
+
def predict_audio(
|
68 |
+
audio,
|
69 |
+
expresso_emotion_classifier: ExpressoEmotionClassifier,
|
70 |
+
emotions_to_predict: Optional[List[str]] = None,
|
71 |
+
):
|
72 |
+
if isinstance(audio, str):
|
73 |
+
speech, _ = torchaudio.load(audio)
|
74 |
+
resampler = torchaudio.transforms.Resample(
|
75 |
+
expresso_emotion_classifier.feature_extractor.sampling_rate
|
76 |
+
)
|
77 |
+
speech = resampler(speech).squeeze().numpy()
|
78 |
+
else:
|
79 |
+
speech = audio
|
80 |
+
|
81 |
+
features = expresso_emotion_classifier.feature_extractor(
|
82 |
+
speech,
|
83 |
+
sampling_rate=expresso_emotion_classifier.feature_extractor.sampling_rate,
|
84 |
+
return_tensors="pt",
|
85 |
+
)
|
86 |
+
features["input_values"] = features["input_values"].cuda()
|
87 |
+
|
88 |
+
logits = expresso_emotion_classifier.model(**features).logits
|
89 |
+
if emotions_to_predict is not None:
|
90 |
+
(indices, label_names) = emotions2new_label_names_and_indices(
|
91 |
+
tuple(emotions_to_predict), tuple(expresso_emotion_classifier.label_names)
|
92 |
+
)
|
93 |
+
logits = logits[:, indices]
|
94 |
+
else:
|
95 |
+
label_names = expresso_emotion_classifier.label_names
|
96 |
+
pred_id = torch.argmax(logits, dim=-1)[0].item()
|
97 |
+
|
98 |
+
return label_names[pred_id], logits.detach().cpu().numpy()
|
99 |
+
|
100 |
+
|
101 |
+
def wav2emotion(
|
102 |
+
wav,
|
103 |
+
expresso_emotion_classifier: ExpressoEmotionClassifier,
|
104 |
+
emotions_to_predict: Optional[List[str]] = None,
|
105 |
+
) -> str:
|
106 |
+
label_logits = predict_audio(
|
107 |
+
audio=wav,
|
108 |
+
expresso_emotion_classifier=expresso_emotion_classifier,
|
109 |
+
emotions_to_predict=emotions_to_predict,
|
110 |
+
)
|
111 |
+
pred_emotion = label_logits[0]
|
112 |
+
return pred_emotion
|
113 |
+
|
114 |
+
|
115 |
+
def wav2emotion_and_sentiment(
|
116 |
+
wav,
|
117 |
+
expresso_emotion_classifier: ExpressoEmotionClassifier,
|
118 |
+
emotions_to_predict: Optional[List[str]] = None,
|
119 |
+
) -> Tuple[str, str]:
|
120 |
+
pred_emotion = wav2emotion(wav, expresso_emotion_classifier, emotions_to_predict)
|
121 |
+
mapped_sentiment = expresso_emotion2_sentiment(pred_emotion)
|
122 |
+
return pred_emotion, mapped_sentiment
|
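A sketch of classifying a generated waveform (a CUDA device is required, since `load_emotion_classifier` moves the model to GPU; the paths are illustrative):

```python
from spiritlm.eval.stsp.utils import (
    load_emotion_classifier,
    wav2emotion_and_sentiment,
)

classifier = load_emotion_classifier("data/stsp_data/model/speech_classifier")
# `wav` can be an audio file path or a waveform array.
emotion, sentiment = wav2emotion_and_sentiment(
    "examples/audio/7143-88743-0029.flac", classifier
)
print(emotion, sentiment)  # e.g. ("happy", "positive")
```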
spiritlm/eval/utils.py
ADDED
@@ -0,0 +1,17 @@
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the FAIR Noncommercial Research License
|
5 |
+
# found in the LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import torchaudio
|
8 |
+
from spiritlm.model.spiritlm_model import Spiritlm
|
9 |
+
|
10 |
+
|
11 |
+
def wav_prompt(spiritlm_model: Spiritlm, wav_path: str) -> str:
|
12 |
+
wav = torchaudio.load(wav_path)[0].squeeze(0)
|
13 |
+
return spiritlm_model.SPEECH_PROMPT_PREFIX + spiritlm_model.speech_tokenizer(wav)
|
14 |
+
|
15 |
+
|
16 |
+
def text_prompt(spiritlm_model: Spiritlm, text: str) -> str:
|
17 |
+
return spiritlm_model.TEXT_PROMPT_PREFIX + text
|
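A short sketch of these helpers: they prepend the modality prefix the model expects, so the results can be concatenated into mixed prompts as in the few-shot setup.

```python
from spiritlm.eval.utils import text_prompt, wav_prompt
from spiritlm.model.spiritlm_model import Spiritlm

spirit_lm = Spiritlm("spirit-lm-base-7b")
text_part = text_prompt(spirit_lm, "The weather is lovely today.")
speech_part = wav_prompt(spirit_lm, "examples/audio/7143-88743-0029.flac")
prompt = text_part + "\n" + speech_part
```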
spiritlm/model/README.md
ADDED
@@ -0,0 +1,82 @@
1 |
+
# Model for Spirit LM
|
2 |
+
This directory contains the Spirit LM model wrapper.
|
3 |
+
|
4 |
+
## Usage examples
|
5 |
+
|
6 |
+
### Model Loading
|
7 |
+
```python
|
8 |
+
from spiritlm.model.spiritlm_model import Spiritlm
|
9 |
+
|
10 |
+
# Spirit LM Base 7B
|
11 |
+
spirit_lm = Spiritlm("spirit-lm-base-7b")
|
12 |
+
|
13 |
+
# Spirit LM Expressive 7B
|
14 |
+
spirit_lm = Spiritlm("spirit-lm-expressive-7b")
|
15 |
+
```
|
16 |
+
|
17 |
+
### Generation examples
|
18 |
+
```python
|
19 |
+
from spiritlm.model.spiritlm_model import OutputModality, GenerationInput, ContentType
|
20 |
+
from transformers import GenerationConfig
|
21 |
+
|
22 |
+
# Generate only text
|
23 |
+
spirit_lm.generate(
|
24 |
+
output_modality=OutputModality.TEXT,
|
25 |
+
interleaved_inputs=[
|
26 |
+
GenerationInput(
|
27 |
+
content="The largest country in the world is",
|
28 |
+
content_type=ContentType.TEXT,
|
29 |
+
)
|
30 |
+
],
|
31 |
+
generation_config=GenerationConfig(
|
32 |
+
temperature=0.9,
|
33 |
+
top_p=0.95,
|
34 |
+
max_new_tokens=50,
|
35 |
+
do_sample=True,
|
36 |
+
),
|
37 |
+
)
|
38 |
+
|
39 |
+
# Expected output format:
|
40 |
+
# [GenerationOuput(content='Russia, with an area of ...', content_type=<ContentType.TEXT: 'TEXT'>)]
|
41 |
+
|
42 |
+
# Generate only speech
|
43 |
+
spirit_lm.generate(
|
44 |
+
output_modality=OutputModality.SPEECH,
|
45 |
+
interleaved_inputs=[
|
46 |
+
GenerationInput(
|
47 |
+
content="examples/audio/7143-88743-0029.flac",
|
48 |
+
content_type=ContentType.SPEECH,
|
49 |
+
)
|
50 |
+
],
|
51 |
+
generation_config=GenerationConfig(
|
52 |
+
temperature=0.9,
|
53 |
+
top_p=0.95,
|
54 |
+
max_new_tokens=200,
|
55 |
+
do_sample=True,
|
56 |
+
),
|
57 |
+
)
|
58 |
+
|
59 |
+
# Expected output format:
|
60 |
+
# [GenerationOuput(content=array([ 3.6673620e-05, 2.6468514e-04, 1.0735081e-03, ...,], dtype=float32), content_type=<ContentType.SPEECH: 'SPEECH'>)]
|
61 |
+
|
62 |
+
|
63 |
+
# Arbitrary generation
|
64 |
+
spirit_lm.generate(
|
65 |
+
output_modality=OutputModality.ARBITRARY,
|
66 |
+
interleaved_inputs=[
|
67 |
+
GenerationInput(
|
68 |
+
content="examples/audio/7143-88743-0029.flac",
|
69 |
+
content_type=ContentType.SPEECH,
|
70 |
+
)
|
71 |
+
],
|
72 |
+
generation_config=GenerationConfig(
|
73 |
+
temperature=0.9,
|
74 |
+
top_p=0.95,
|
75 |
+
max_new_tokens=200,
|
76 |
+
do_sample=True,
|
77 |
+
),
|
78 |
+
)
|
79 |
+
# Expected output format is a list of GenerationOuput where the content type can be `ContentType.TEXT` or `ContentType.SPEECH`:
|
80 |
+
# [GenerationOuput(content='xxx', content_type=<ContentType.TEXT: 'TEXT'>), GenerationOuput(content=array([ 0.00553902, -0.03210586, ... ], dtype=float32), content_type=<ContentType.SPEECH: 'SPEECH'>), GenerationOuput(content='yyy', content_type=<ContentType.TEXT: 'TEXT'>), GenerationOuput(content=array([0.04051103, 0.03596291, 0.03381396, ..., 0.05103811, 0.05429034, ..,,], dtype=float32), content_type=<ContentType.SPEECH: 'SPEECH'>)]
|
81 |
+
```
|
82 |
+
See more examples with other modalities in [examples/speech_generation/spirit_model.ipynb](../../examples/speech_generation/spirit_model.ipynb).
|